diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index d6840afab..6662c4edd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -746,7 +746,7 @@ class DecorrelateFunction(torch.autograd.Function):
             inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
             norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-            loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1
+            loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
             if random.random() < 0.01:
                 logging.info(f"Decorrelate: loss = {loss}")
             loss.backward()
@@ -758,7 +758,9 @@ class DecorrelateFunction(torch.autograd.Function):
        # `loss ** 0.5` times the magnitude of the original grad.
        x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
        x_grad_old_scale = (x_grad ** 2).sum(dim=1)
-       decorr_loss_scale = ctx.scale
+
+       decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
+
        scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
        x_grad_new = x_grad_new * scale.unsqueeze(-1)
@@ -776,27 +778,56 @@ class Decorrelate(torch.nn.Module):
     This module does nothing in the forward pass, but in the backward pass, modifies
     the derivatives in such a way as to encourage the dimensions of its input to become
     decorrelated.
+
+    Args:
+      num_channels:  The number of channels, e.g. 256.
+      apply_prob_decay:  The probability with which we apply this each time, in
+           training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
+      scale: This number determines the scale of the gradient contribution from
+           this module, relative to whatever the gradient was before;
+           this is applied per frame or pixel, by scaling gradients.
+      eps: An epsilon used to prevent division by zero.
+      beta: A value 0 < beta < 1 that controls decay of covariance stats
+      channel_dim: The dimension of the input corresponding to the channel, e.g.
+           -1, 0, 1, 2.
+    """
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
+                 apply_prob_decay: int = 1000,
                  eps: float = 1.0e-05,
                  beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
+        self.apply_prob_decay = apply_prob_decay
         self.eps = eps
         self.beta = beta
         self.channel_dim = channel_dim

         self.register_buffer('cov', torch.zeros(num_channels, num_channels))
+        # step_buf is a copy of step, included so it will be loaded/saved with
+        # the model.
+        self.register_buffer('step_buf', torch.tensor(0))
         self.step = 0
+
+    def load_state_dict(self, *args, **kwargs):
+        super(Decorrelate, self).load_state_dict(*args, **kwargs)
+        self.step = self.step_buf.item()
+
+
     def forward(self, x: Tensor) -> Tensor:
         if not self.training:
             return x
         else:
+            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
+            self.step += 1
+            self.step_buf.fill_(self.step)
+            if random.random() > apply_prob:
+                return x
             with torch.cuda.amp.autocast(enabled=False):
                 ans = DecorrelateFunction.apply(x, self.cov.clone(),
                                                 self.scale, self.eps, self.beta,
@@ -807,7 +838,6 @@ class Decorrelate(torch.nn.Module):
                 cov = torch.matmul(x.t(), x)
                 with torch.no_grad():
                     self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-                self.step += 1

             return ans  # ans == x.
@@ -825,9 +855,8 @@ class JoinDropout(torch.nn.Module):

     Args:
       num_channels:  The number of channels, e.g. 256.
-      apply_prob:  The probability with which we apply this each time, in
-           training mode.  This is to save time (but of course it
-           will tend to make the effect weaker).
+      apply_prob_decay:  The probability with which we apply this each time, in
+           training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
       dropout_rate:  This number determines the average dropout probability
            (it will actually vary across dimensions).
       eps:  An epsilon used to prevent division by zero.
@@ -925,7 +954,8 @@ class JoinDropout(torch.nn.Module):


     def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
-        if not self.training or random.random() > self.apply_prob:
+        apply_prob = self.apply_prob
+        if not self.training or random.random() > apply_prob:
             return bypass + x
         else:
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
@@ -1049,6 +1079,28 @@ def _test_gauss_proj_drop():
         m1.eval()
         m2.eval()

+def _test_decorrelate():
+    D = 384
+    x = torch.randn(30000, D)
+    # give it a non-unit covariance.
+    m = torch.randn(D, D) * (D ** -0.5)
+    _, S, _ = m.svd()
+    print("M eigs = ", S[::10])
+    x = torch.matmul(x, m)
+
+
+    # check that class Decorrelate does not crash when running..
+    decorrelate = Decorrelate(D)
+    x.requires_grad = True
+    y = decorrelate(x)
+    y.sum().backward()
+
+    decorrelate2 = Decorrelate(D)
+    decorrelate2.load_state_dict(decorrelate.state_dict())
+    assert decorrelate2.step == decorrelate.step
+
+
+
 def _test_join_dropout():
     D = 384
     x = torch.randn(30000, D)
@@ -1060,13 +1112,6 @@ def _test_join_dropout():
     x = torch.matmul(x, m)


-    if True:
-        # check that class Decorrelate does not crash when running..
-        decorrelate = Decorrelate(D)
-        x.requires_grad = True
-        y = decorrelate(x)
-        y.sum().backward()
-
     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
         m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
@@ -1089,6 +1134,7 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_decorrelate()
     _test_join_dropout()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 03f7fe6f5..0a85841fe 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -198,10 +198,8 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )

-        self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-        self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-        self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-        self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout = nn.Dropout(dropout)
+        self.decorrelate = Decorrelate(d_model, scale=0.02)


     def forward(
@@ -245,7 +243,7 @@ class ConformerEncoderLayer(nn.Module):
             alpha = 1.0

         # macaron style feed forward module
-        src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
+        src = src + self.dropout(self.feed_forward_macaron(src))

         # multi-headed self-attention module
         src_att = self.self_attn(
@@ -256,16 +254,18 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = self.dropout_self_attn(src, src_att)
+        src = src + self.dropout(src_att)

         # convolution module
-        src = self.dropout_conv(src, self.conv_module(src))
+        src = src + self.dropout(self.conv_module(src))

         # feed forward module
-        src = self.dropout_ff(src, self.feed_forward(src))
+        src = src + self.dropout(self.feed_forward(src))

         src = self.norm_final(self.balancer(src))

+        src = self.decorrelate(src)
+
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig
@@ -1032,7 +1032,6 @@ class Conv2dSubsampling(nn.Module):
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
-        self.decorrelate = Decorrelate(out_channels)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(
             channel_dim=-1, min_positive=0.45, max_positive=0.55
@@ -1057,7 +1056,6 @@ class Conv2dSubsampling(nn.Module):
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.out_norm(x)
-        x = self.decorrelate(x)
         x = self.out_balancer(x)
         return x
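
Usage sketch (illustrative, not part of the patch): a minimal example of exercising the updated Decorrelate module, mirroring _test_decorrelate above. The import path (scaling), the channel count of 256, and the tensor shape are assumptions for the example only.

    import torch
    from scaling import Decorrelate  # assumes the patched scaling.py is importable

    # In training mode the forward pass returns x unchanged; the backward pass adds
    # a gradient term that encourages the channels of x to become decorrelated.
    # Each call is applied with probability apply_prob_decay / (apply_prob_decay + step).
    decorrelate = Decorrelate(num_channels=256, scale=0.02, apply_prob_decay=1000)
    decorrelate.train()

    x = torch.randn(500, 256, requires_grad=True)
    for _ in range(3):
        y = decorrelate(x)
        y.sum().backward()
        x.grad = None

    # The step counter is mirrored into the step_buf buffer, so it survives a
    # state_dict round trip via the overridden load_state_dict.
    decorrelate2 = Decorrelate(num_channels=256)
    decorrelate2.load_state_dict(decorrelate.state_dict())
    assert decorrelate2.step == decorrelate.step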