Make it run only for first 3k steps; larger scale; remove limit of 1.0
This commit is contained in:
parent eeb95ed502
commit 86c2d0fcc0
@@ -847,13 +847,7 @@ class DecorrelateFunction(torch.autograd.Function):
             decorr_x_grad = x.grad

-            # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once
-            # the loss starts getting quite small (less than 1), we start using
-            # smaller derivatives.
-
-            decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
-            scale = decorr_loss_scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
+            scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
             decorr_x_grad = decorr_x_grad * scale

             x_grad = x_grad + decorr_x_grad
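The surviving line rescales the decorrelation gradient so that its RMS magnitude is a fixed fraction (ctx.scale) of the incoming gradient's RMS; the loss-dependent clamp factor is gone, which is the "remove limit of 1.0" in the commit message. A minimal standalone sketch of that rescaling step, with illustrative names (not the repo's function):

import torch

def add_decorr_grad(x_grad: torch.Tensor,
                    decorr_x_grad: torch.Tensor,
                    scale: float = 0.1,
                    eps: float = 1.0e-20) -> torch.Tensor:
    # Rescale decorr_x_grad so its RMS equals `scale` times the RMS of
    # x_grad, then add it to the incoming gradient.
    factor = scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + eps)) ** 0.5
    return x_grad + decorr_x_grad * factor

Previously the effective scale was further multiplied by loss.detach().clamp(min=0.0, max=1.0), so the penalty's pull weakened once the decorrelation loss dropped below 1; after this commit it stays at full strength until the step cutoff introduced below.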
@@ -874,8 +868,7 @@ class Decorrelate(torch.nn.Module):

     Args:
       num_channels: The number of channels, e.g. 256.
-      apply_prob_decay: The probability with which we apply this each time, in
-            training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
+      apply_steps: The number of steps for which we apply this penalty.
       scale: This number determines the scale of the gradient contribution from
             this module, relative to whatever the gradient was before;
             this is applied per frame or pixel, by scaling gradients.
@@ -888,13 +881,13 @@ class Decorrelate(torch.nn.Module):
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
-                 apply_prob_decay: int = 1000,
+                 apply_steps: int = 3000,
                  eps: float = 1.0e-05,
                  beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
-        self.apply_prob_decay = apply_prob_decay
+        self.apply_steps = apply_steps
         self.eps = eps
         self.beta = beta
         self.channel_dim = channel_dim
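Together with the docstring change above, this replaces the decaying random schedule with a hard step cutoff. A hypothetical construction call under the new signature (Decorrelate is the class modified in this diff; the call site itself is illustrative):

decorr = Decorrelate(num_channels=256,
                     scale=0.1,
                     apply_steps=3000)  # penalty active only for the first 3000 training steps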
@@ -912,14 +905,11 @@ class Decorrelate(torch.nn.Module):


     def forward(self, x: Tensor) -> Tensor:
-        if not self.training:
+        if not self.training or self.step >= self.apply_steps:
             return x
         else:
-            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
             self.step += 1
             self.step_buf.fill_(float(self.step))
-            if random.random() > apply_prob:
-                return x
             with torch.cuda.amp.autocast(enabled=False):
                 x = x.to(torch.float32)
                 # the function updates self.cov in its backward pass (it needs the gradient
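The forward pass now gates deterministically: before, the penalty fired with probability apply_prob_decay / (apply_prob_decay + step), so it tailed off but never fully stopped; now it fires on every training step until apply_steps is reached and is a no-op afterwards (and always in eval mode). A reduced sketch of just the new gating logic, with the penalty itself left as a comment where the real module attaches it:

import torch
from torch import Tensor

class StepGated(torch.nn.Module):
    """Illustrative shell: applies a training-time penalty only for the
    first `apply_steps` steps, then becomes the identity."""

    def __init__(self, apply_steps: int = 3000):
        super().__init__()
        self.apply_steps = apply_steps
        self.step = 0

    def forward(self, x: Tensor) -> Tensor:
        if not self.training or self.step >= self.apply_steps:
            return x  # inference, or past the cutoff
        self.step += 1
        # ... the real module attaches the decorrelation penalty to x here ...
        return x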
@@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
         )

-        self.decorrelate = Decorrelate(d_model, scale=0.05)
+        self.decorrelate = Decorrelate(d_model, scale=0.1)


     def forward(
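This call site doubles the requested gradient contribution from 0.05 to 0.1, the "larger scale" of the commit message. The diff does not show where self.decorrelate is invoked; presumably it wraps an activation somewhere in the encoder's forward pass, along the lines of this hypothetical placement:

x = self.decorrelate(x)  # identity in eval mode and once apply_steps is exhausted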