From 2bbc63a2f5831197cc623627787010aee44c99df Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 10 Jun 2022 23:33:16 +0800
Subject: [PATCH] Change first1k to decay1k

---
 .../ASR/pruned_transducer_stateless2/scaling.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 2d1b83e33..2a6e173e1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -841,7 +841,8 @@ class Decorrelate(torch.nn.Module):
 
     Args:
       num_channels: The number of channels, e.g. 256.
-      apply_steps: The number of steps for which we apply this penalty.
+      apply_prob_decay: The probability with which we apply this each time, in
+             training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
       scale: This number determines the scale of the gradient contribution from
              this module, relative to whatever the gradient was before;
              this is applied per frame or pixel, by scaling gradients.
@@ -854,13 +855,13 @@ class Decorrelate(torch.nn.Module):
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
-                 apply_steps: int = 1000,
+                 apply_prob_decay: int = 1000,
                  eps: float = 1.0e-05,
                  beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
-        self.apply_steps = apply_steps
+        self.apply_prob_decay = apply_prob_decay
         self.eps = eps
         self.beta = beta
         self.channel_dim = channel_dim
@@ -878,11 +879,14 @@ class Decorrelate(torch.nn.Module):
 
 
     def forward(self, x: Tensor) -> Tensor:
-        if not self.training or self.step >= self.apply_steps:
+        if not self.training:
             return x
         else:
+            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
             self.step += 1
             self.step_buf.fill_(float(self.step))
+            if random.random() > apply_prob:
+                return x
             with torch.cuda.amp.autocast(enabled=False):
                 x = x.to(torch.float32)
                 # the function updates self.cov in its backward pass (it needs the gradient
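
Note (not part of the patch): a minimal sketch of the schedule that the new
apply_prob_decay parameter produces. The helper name apply_prob below is
hypothetical and only illustrates the formula from the docstring,
apply_prob_decay / (apply_prob_decay + step); with the default of 1000 the
Decorrelate module fires with probability 1.0 at step 0, 0.5 at step 1000,
and about 0.1 at step 9000, instead of being switched off abruptly after the
first 1000 steps as with the old apply_steps behaviour.

    # Illustrative only: probability that Decorrelate is applied at a given step,
    # assuming the formula stated in the patched docstring.
    def apply_prob(step: int, apply_prob_decay: int = 1000) -> float:
        return apply_prob_decay / (step + apply_prob_decay)

    for step in (0, 100, 1000, 9000):
        print(step, apply_prob(step))   # -> 1.0, ~0.909, 0.5, 0.1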