diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 05eef1554..8dcfdddc7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -847,13 +847,7 @@ class DecorrelateFunction(torch.autograd.Function):
 
             decorr_x_grad = x.grad
 
-            # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once
-            # the loss starts getting quite small (less than 1), we start using
-            # smaller derivatives.
-
-
-            decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
-            scale = decorr_loss_scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
+            scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
             decorr_x_grad = decorr_x_grad * scale
 
             x_grad = x_grad + decorr_x_grad
@@ -874,8 +868,7 @@ class Decorrelate(torch.nn.Module):
 
     Args:
       num_channels: The number of channels, e.g. 256.
-      apply_prob_decay: The probability with which we apply this each time, in
-           training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
+      apply_steps: The number of steps for which we apply this penalty.
       scale: This number determines the scale of the gradient contribution from
              this module, relative to whatever the gradient was before;
              this is applied per frame or pixel, by scaling gradients.
@@ -888,13 +881,13 @@ class Decorrelate(torch.nn.Module):
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
-                 apply_prob_decay: int = 1000,
+                 apply_steps: int = 3000,
                  eps: float = 1.0e-05,
                  beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
-        self.apply_prob_decay = apply_prob_decay
+        self.apply_steps = apply_steps
         self.eps = eps
         self.beta = beta
         self.channel_dim = channel_dim
@@ -912,14 +905,11 @@ class Decorrelate(torch.nn.Module):
 
 
     def forward(self, x: Tensor) -> Tensor:
-        if not self.training:
+        if not self.training or self.step >= self.apply_steps:
             return x
         else:
-            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
             self.step += 1
             self.step_buf.fill_(float(self.step))
-            if random.random() > apply_prob:
-                return x
             with torch.cuda.amp.autocast(enabled=False):
                 x = x.to(torch.float32)
                 # the function updates self.cov in its backward pass (it needs the gradient
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 66334688b..fba44ec17 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
         )
 
-        self.decorrelate = Decorrelate(d_model, scale=0.05)
+        self.decorrelate = Decorrelate(d_model, scale=0.1)
 
 
     def forward(
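
Aside, not part of the patch: the first hunk's effect is easier to see in isolation. The sketch below is a hypothetical stand-alone rendering of the new gradient-combination rule (the helper name combine_grads is ours, not icefall's): the decorrelation-penalty gradient is rescaled so that its RMS is `scale` times the RMS of the main gradient, and the old loss-dependent clamp factor is dropped, so the penalty keeps a constant relative strength even once the decorrelation loss becomes small.

import torch

def combine_grads(x_grad: torch.Tensor,
                  decorr_x_grad: torch.Tensor,
                  scale: float = 0.1) -> torch.Tensor:
    # Hypothetical helper mirroring the patched DecorrelateFunction.backward:
    # choose factor so that RMS(factor * decorr_x_grad) == scale * RMS(x_grad);
    # the 1.0e-20 term guards against division by zero.
    factor = scale * (
        (x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)
    ) ** 0.5
    return x_grad + decorr_x_grad * factor

g = torch.randn(4, 256)           # main gradient
p = torch.randn(4, 256) * 100.0   # penalty gradient with much larger raw norm
out = combine_grads(g, p, scale=0.1)
# The added component contributes ~0.1 * RMS(g) regardless of p's raw scale.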
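
The remaining hunks replace random application with a decaying probability apply_prob_decay / (apply_prob_decay + step) by a hard cutoff: the penalty runs on every training step until the counter reaches apply_steps, then the module becomes a no-op. A minimal self-contained sketch of that gating pattern, assuming a module with a persistent step counter (the class name StepGatedPenalty and the _penalty hook are hypothetical):

import torch
from torch import Tensor

class StepGatedPenalty(torch.nn.Module):
    def __init__(self, apply_steps: int = 3000):
        super().__init__()
        self.apply_steps = apply_steps
        self.step = 0
        # Buffer mirrors self.step so the count survives checkpointing,
        # as step_buf does in Decorrelate.
        self.register_buffer("step_buf", torch.zeros(()))

    def forward(self, x: Tensor) -> Tensor:
        # Inactive in eval mode, and permanently once apply_steps is reached.
        if not self.training or self.step >= self.apply_steps:
            return x
        self.step += 1
        self.step_buf.fill_(float(self.step))
        return self._penalty(x)

    def _penalty(self, x: Tensor) -> Tensor:
        return x  # stand-in for the real decorrelation term

Compared with probabilistic gating, the cutoff is deterministic and, after warm-up, skips the module entirely with no random.random() call per forward.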