diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 9f6180eab..9da5bb3b6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -34,7 +34,8 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp
+    random_clamp,
+    with_loss,
 )
 
 from torch import Tensor, nn
@@ -1110,16 +1111,37 @@ class RelPositionMultiheadAttention(nn.Module):
                                                  storage_offset=pos_weights.stride(3) * (seq_len - 1))
 
+        # caution: they are really scores at this point.
         attn_output_weights = torch.matmul(q, k) + pos_weights
 
+        # The following is a soft way of encouraging the attention scores to not be too large;
+        # at training time, once they get outside a certain range, -5.0..5.0 currently, we
+        # randomly either leave them as-is or truncate them to that range.
         if attn_weights_max is not None:
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
                                                prob=0.5)
 
-        # attn_output_weights: (batch, head, time1, time2)
+        if training and random.random() < 0.1:
+            # This is a harder way of limiting the attention scores to not be too large.
+            # It incurs a penalty if any of them has an absolute value greater than 50.0;
+            # this should be outside the normal range of the attention scores.  We use
+            # this mechanism instead of, say, a limit on entropy, because once the entropy
+            # gets very small, gradients through the softmax can become very small, and
+            # mechanisms like that become ineffective.
+            attn_weights_limit = 50.0
+            # caution: this penalty will be affected by grad-scaling in amp.
+            # It's OK; this is just an emergency brake, and under normal
+            # conditions it shouldn't be active.
+            attn_weights_penalty = 1.0e-04
+            aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
+                                               attn_weights_limit).relu()
+            attn_output_weights = with_loss(attn_output_weights,
+                                            aux_loss)
+
+        # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, seq_len, seq_len
         )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 87626b780..922ef052a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -608,6 +608,22 @@ class Whiten(nn.Module):
                                        self.whitening_limit,
                                        self.grad_scale)
 
+
+class WithLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, y: Tensor):
+        ctx.y_shape = y.shape
+        return x
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        return ans_grad, torch.ones(ctx.y_shape,
+                                    dtype=ans_grad.dtype,
+                                    device=ans_grad.device)
+def with_loss(x, y):
+    # returns x but adds y.sum() to the loss function.
+    return WithLoss.apply(x, y)
+
+
 def _no_op(x: Tensor) -> Tensor:
     if (torch.jit.is_scripting()):
         return x
@@ -617,7 +633,6 @@ def _no_op(x: Tensor) -> Tensor:
         return x.chunk(1, dim=-1)[0]
 
 
-
 class Identity(torch.nn.Module):
     def __init__(self):
         super(Identity, self).__init__()
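
Note on the new `with_loss` helper: the forward pass returns `x` unchanged, and the backward pass returns an all-ones gradient for `y`, which is equivalent to adding `y.sum()` to the final loss. Below is a minimal standalone sketch (not part of the patch) that reproduces the `WithLoss` function added above and checks this equivalence against an explicit `x.pow(2).sum() + y.sum()` loss; the toy tensors and the squared loss are illustrative choices, not anything taken from the recipe.

```python
# Standalone sketch (not part of the patch): reproduce the WithLoss autograd
# function from scaling.py and check that with_loss(x, y) behaves, gradient-wise,
# like "return x, but add y.sum() to the loss".  Note that the *value* of the
# loss is unchanged; only the gradients are affected.
import torch
from torch import Tensor


class WithLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor) -> Tensor:
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # grad w.r.t. x passes through unchanged; grad w.r.t. y is all-ones,
        # exactly as if y.sum() had been added to the final scalar loss.
        return ans_grad, torch.ones(ctx.y_shape,
                                    dtype=ans_grad.dtype,
                                    device=ans_grad.device)


def with_loss(x: Tensor, y: Tensor) -> Tensor:
    # returns x but adds y.sum() to the loss function.
    return WithLoss.apply(x, y)


if __name__ == "__main__":
    x = torch.randn(3, 4, requires_grad=True)
    y = torch.randn(3, 4, requires_grad=True)

    # Loss built through with_loss: numerically equal to x.pow(2).sum(), but
    # its gradients should match those of x.pow(2).sum() + y.sum().
    with_loss(x, y).pow(2).sum().backward()
    grad_x, grad_y = x.grad.clone(), y.grad.clone()

    x.grad = None
    y.grad = None
    (x.pow(2).sum() + y.sum()).backward()

    assert torch.allclose(grad_x, x.grad)
    assert torch.allclose(grad_y, y.grad)
    print("with_loss(x, y): gradients match x.pow(2).sum() + y.sum()")
```

Because the extra term enters only through the backward pass, anything that rescales gradients (such as the amp grad scaler mentioned in the conformer.py comment) rescales this penalty as well, which is consistent with the "emergency brake" framing in that comment.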