diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 41726c9fb..9c47abb30 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -35,7 +35,7 @@ from scaling import (
     Identity,
     _diag,
     random_clamp,
-    with_loss,
+    penalize_abs_values_gt,
 )
 from torch import Tensor, nn
 
@@ -647,6 +647,11 @@ class AttentionDownsample(torch.nn.Module):
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
 
         scores = (src * self.query).sum(dim=-1, keepdim=True)
+
+        scores = penalize_abs_values_gt(scores,
+                                        limit=5.0,
+                                        penalty=1.0e-04)
+
         weights = scores.softmax(dim=1)
 
         # ans1 is the first `in_channels` channels of the output
@@ -1116,15 +1121,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # this mechanism instead of, say, a limit on entropy, because once the entropy
         # gets very small gradients through the softmax can become very small, and
         # some mechanisms like that become ineffective.
-        attn_weights_limit = 25.0
-        # caution: this penalty will be affected by grad-scaling in amp.
-        # It's OK; this is just an emergency brake, and under normal
-        # conditions it shouldn't be active
-        attn_weights_penalty = 1.0e-04
-        aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
-                                           attn_weights_limit).relu()
-        attn_output_weights = with_loss(attn_output_weights,
-                                        aux_loss)
+        attn_output_weights = penalize_abs_values_gt(attn_output_weights,
+                                                     limit=25.0,
+                                                     penalty=1.0e-04)
 
         # attn_output_weights: (batch, head, time1, time2)
 
@@ -1545,22 +1544,22 @@ class AttentionCombine(nn.Module):
             (num_frames, num_channels, num_inputs)
         )
 
-        weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
+        scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
 
         if random.random() < 0.002:
-            logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
+            logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}")
 
         if self.training:
             # random masking..
             mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
-                                       size=(num_frames,), device=weights.device).unsqueeze(1)
+                                       size=(num_frames,), device=scores.device).unsqueeze(1)
             # mask will have rows like: [ False, False, False, True, True, .. ]
-            arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+            arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand(
                 num_frames, num_inputs)
             mask = arange >= mask_start
             apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
-                                                             device=weights.device) < self.single_prob,
+                                                             device=scores.device) < self.single_prob,
                                                   mask_start < num_inputs)
             single_prob_mask = torch.logical_and(apply_single_prob,
                                                  arange < mask_start - 1)
@@ -1568,8 +1567,13 @@ class AttentionCombine(nn.Module):
 
             mask = torch.logical_or(mask, single_prob_mask)
 
-            weights = weights.masked_fill(mask, float('-inf'))
-        weights = weights.softmax(dim=1)
+            scores = scores.masked_fill(mask, float('-inf'))
+
+        scores = penalize_abs_values_gt(scores,
+                                        limit=10.0,
+                                        penalty=1.0e-04)
+
+        weights = scores.softmax(dim=1)
 
         # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
         ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index ab8a12bf0..5b4fece3d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -545,6 +545,24 @@ class ActivationBalancer(torch.nn.Module):
         return _no_op(x)
 
 
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+    """
+    Returns x unmodified, but in backprop will put a penalty on the amount by
+    which the absolute values of elements of x exceed the limit "limit".  E.g.
+    if limit == 10.0, any element of x with absolute value over 10 gets a penalty.
+
+    Caution: the value of this penalty will be affected by grad scaling used
+    in automatic mixed precision training.  Given how we use this, that
+    shouldn't really matter, or may even be helpful; we just use this to keep
+    really implausible values of scores from being given to softmax.
+    """
+    aux_loss = penalty * (x.abs() - limit).relu()
+    # note: we don't do sum() here on aux_loss, but it's as if we had done
+    # sum() due to how with_loss() works.
+    x = with_loss(x, aux_loss)
+    # you must use x for something, or this will be ineffective.
+    return x
+
 def _diag(x: Tensor):
     # like .diag(), but works for tensors with 3 dims.
     if x.ndim == 2:
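For intuition, here is a minimal, self-contained sketch of the gradient effect that `penalize_abs_values_gt` produces. The diff does not show `with_loss`, so this sketch does not use it; instead it assumes the auxiliary loss is effectively summed into the training loss ("as if we had done sum()", per the comment above) and folds the corresponding gradient, `penalty * sign(x)` wherever `|x| > limit`, directly into a custom autograd function. The names `_PenalizeAbsValuesGtSketch` and `penalize_abs_values_gt_sketch` are illustrative only and not part of the repository.

```python
import torch
from torch import Tensor


class _PenalizeAbsValuesGtSketch(torch.autograd.Function):
    """Identity in the forward pass.  In the backward pass, adds the gradient
    of penalty * relu(|x| - limit).sum() to whatever gradient flows in, so
    elements with |x| > limit are nudged back toward the limit."""

    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        # d/dx [ penalty * relu(|x| - limit) ] == penalty * sign(x) where |x| > limit
        extra_grad = ctx.penalty * x.sign() * (x.abs() > ctx.limit)
        # No gradients for the non-tensor arguments `limit` and `penalty`.
        return grad_output + extra_grad, None, None


def penalize_abs_values_gt_sketch(x: Tensor, limit: float, penalty: float) -> Tensor:
    """Returns x unchanged, but elements with |x| > limit incur an extra
    gradient of size `penalty` in backprop (same idea as the scaling.py helper)."""
    return _PenalizeAbsValuesGtSketch.apply(x, limit, penalty)


if __name__ == "__main__":
    scores = torch.tensor([0.5, 30.0, -40.0], requires_grad=True)
    out = penalize_abs_values_gt_sketch(scores, limit=25.0, penalty=1.0e-04)
    out.sum().backward()
    # The in-range element keeps its gradient of 1.0 from sum(); the two
    # out-of-range elements pick up an extra +/- 1e-04 from the penalty.
    print(scores.grad)
```

The version added to scaling.py gets the same effect by attaching `aux_loss` to `x` with `with_loss`, which is why its comment warns that `x` must actually be used downstream for the penalty to have any effect.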