Make it run only for first 3k steps; larger scale; remove limit of 1.0

Daniel Povey 2022-06-10 16:33:22 +08:00
parent eeb95ed502
commit 86c2d0fcc0
2 changed files with 6 additions and 16 deletions

@@ -847,13 +847,7 @@ class DecorrelateFunction(torch.autograd.Function):
             decorr_x_grad = x.grad
-            # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once
-            # the loss starts getting quite small (less than 1), we start using
-            # smaller derivatives.
-            decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
-            scale = decorr_loss_scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
+            scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
             decorr_x_grad = decorr_x_grad * scale
             x_grad = x_grad + decorr_x_grad
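
The removed clamp meant the decorrelation gradient faded out once the loss fell below 1.0; after this change its RMS is always scale times the RMS of the main gradient. A minimal standalone sketch of the new rescaling (the function name is illustrative; tensor names follow the hunk):

    import torch

    def add_decorr_grad(x_grad: torch.Tensor,
                        decorr_x_grad: torch.Tensor,
                        scale: float = 0.1) -> torch.Tensor:
        # Match the RMS of the decorrelation gradient to `scale` times the
        # RMS of the main gradient; 1.0e-20 guards against division by zero.
        factor = scale * ((x_grad ** 2).mean()
                          / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
        return x_grad + decorr_x_grad * factor
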
@@ -874,8 +868,7 @@ class Decorrelate(torch.nn.Module):
         Args:
           num_channels:  The number of channels, e.g. 256.
-          apply_prob_decay:  The probability with which we apply this each time, in
-                 training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
+          apply_steps:  The number of steps for which we apply this penalty.
           scale: This number determines the scale of the gradient contribution from
                 this module, relative to whatever the gradient was before;
                 this is applied per frame or pixel, by scaling gradients.
@@ -888,13 +881,13 @@ class Decorrelate(torch.nn.Module):
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
-                 apply_prob_decay: int = 1000,
+                 apply_steps: int = 3000,
                  eps: float = 1.0e-05,
                  beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
-        self.apply_prob_decay = apply_prob_decay
+        self.apply_steps = apply_steps
         self.eps = eps
         self.beta = beta
         self.channel_dim = channel_dim
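
A usage sketch under the new signature (the import path is an assumption; Decorrelate comes from the file modified above):

    import torch
    from scaling import Decorrelate  # assumed import path

    decorrelate = Decorrelate(num_channels=256)       # defaults: scale=0.1, apply_steps=3000
    x = torch.randn(8, 100, 256, requires_grad=True)  # hypothetical (batch, time, channel)
    y = decorrelate(x)  # shapes preserved; the penalty acts through the gradient,
                        # and only during the first apply_steps training steps
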
@@ -912,14 +905,11 @@ class Decorrelate(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        if not self.training:
+        if not self.training or self.step >= self.apply_steps:
             return x
         else:
-            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
             self.step += 1
             self.step_buf.fill_(float(self.step))
-            if random.random() > apply_prob:
-                return x
             with torch.cuda.amp.autocast(enabled=False):
                 x = x.to(torch.float32)
                 # the function updates self.cov in its backward pass (it needs the gradient
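
The gate thus changes from probabilistic to deterministic. A standalone sketch contrasting the two rules (the helper name is illustrative; use_old_rule reproduces the pre-commit behavior):

    import random

    def should_apply(step: int, training: bool,
                     apply_steps: int = 3000,
                     apply_prob_decay: int = 1000,
                     use_old_rule: bool = False) -> bool:
        if not training:
            return False
        if use_old_rule:
            # Before this commit: apply with probability
            # apply_prob_decay / (apply_prob_decay + step), which decays
            # toward zero but never reaches it.
            return random.random() <= apply_prob_decay / (step + apply_prob_decay)
        # After this commit: apply on every step until apply_steps, then never.
        return step < apply_steps
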

@@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
         )
-        self.decorrelate = Decorrelate(d_model, scale=0.05)
+        self.decorrelate = Decorrelate(d_model, scale=0.1)

     def forward(
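
This matches the new Decorrelate default of scale=0.1 and doubles the per-frame gradient contribution of the decorrelation penalty in the Conformer encoder; this is the "larger scale" part of the commit message.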