Penalize too large weights in softmax of AttentionDownsample()

Daniel Povey 2022-10-21 20:12:36 +08:00
parent 9f68b5717c
commit bdbd2cfce6
2 changed files with 39 additions and 17 deletions

Changed file 1 of 2:

@@ -35,7 +35,7 @@ from scaling import (
Identity,
_diag,
random_clamp,
-with_loss,
+penalize_abs_values_gt,
)
from torch import Tensor, nn
@@ -647,6 +647,11 @@ class AttentionDownsample(torch.nn.Module):
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
scores = (src * self.query).sum(dim=-1, keepdim=True)
+scores = penalize_abs_values_gt(scores,
+                                limit=5.0,
+                                penalty=1.0e-04)
weights = scores.softmax(dim=1)
# ans1 is the first `in_channels` channels of the output
@@ -1116,15 +1121,9 @@ class RelPositionMultiheadAttention(nn.Module):
# this mechanism instead of, say, a limit on entropy, because once the entropy
# gets very small gradients through the softmax can become very small, and
# some mechanisms like that become ineffective.
-attn_weights_limit = 25.0
-# caution: this penalty will be affected by grad-scaling in amp.
-# It's OK; this is just an emergency brake, and under normal
-# conditions it shouldn't be active
-attn_weights_penalty = 1.0e-04
-aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
-                                   attn_weights_limit).relu()
-attn_output_weights = with_loss(attn_output_weights,
-                                aux_loss)
+attn_output_weights = penalize_abs_values_gt(attn_output_weights,
+                                             limit=25.0,
+                                             penalty=1.0e-04)
# attn_output_weights: (batch, head, time1, time2)
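
The comment above is the rationale for penalizing over-large logits rather than constraining entropy: once one logit drifts far past the others, the softmax output saturates and the gradient flowing back through it essentially vanishes, so a constraint applied after the softmax has nothing left to push against. A standalone toy illustration (not code from this repo):

import torch

# One logit far above the 25.0 limit used above: the softmax output is
# saturated and its gradient w.r.t. the logits is down around 1e-13.
scores = torch.tensor([0.0, 30.0], requires_grad=True)
weights = scores.softmax(dim=0)
weights[0].backward()
print(weights)      # roughly [9.4e-14, 1.0]
print(scores.grad)  # roughly [9.4e-14, -9.4e-14]
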
@@ -1545,22 +1544,22 @@ class AttentionCombine(nn.Module):
(num_frames, num_channels, num_inputs)
)
-weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
+scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
if random.random() < 0.002:
logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}")
if self.training:
# random masking..
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
-size=(num_frames,), device=weights.device).unsqueeze(1)
+size=(num_frames,), device=scores.device).unsqueeze(1)
# mask will have rows like: [ False, False, False, True, True, .. ]
-arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand(
num_frames, num_inputs)
mask = arange >= mask_start
apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
-device=weights.device) < self.single_prob,
+device=scores.device) < self.single_prob,
mask_start < num_inputs)
single_prob_mask = torch.logical_and(apply_single_prob,
arange < mask_start - 1)
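
As a toy sketch of the mask rows described in the comment above (made-up sizes; the real module draws mask_start from the wider range high=int(num_inputs / self.random_prob), so some rows can also stay fully unmasked):

import torch

num_frames, num_inputs = 4, 5
# Each frame keeps a random prefix of the stacked inputs and masks the rest.
mask_start = torch.randint(low=1, high=num_inputs, size=(num_frames,)).unsqueeze(1)
arange = torch.arange(num_inputs).unsqueeze(0).expand(num_frames, num_inputs)
mask = arange >= mask_start
print(mask)  # rows like [False, False, False, True, True]
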
@@ -1568,8 +1567,13 @@ class AttentionCombine(nn.Module):
mask = torch.logical_or(mask,
single_prob_mask)
-weights = weights.masked_fill(mask, float('-inf'))
-weights = weights.softmax(dim=1)
+scores = scores.masked_fill(mask, float('-inf'))
+scores = penalize_abs_values_gt(scores,
+                                limit=10.0,
+                                penalty=1.0e-04)
+weights = scores.softmax(dim=1)
# (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
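
To make the shape comment above concrete, a minimal standalone sketch of the combining step with made-up sizes (ignoring the masking and the new penalty):

import torch

num_frames, num_channels, num_inputs = 4, 8, 3
stacked_inputs = torch.randn(num_frames, num_channels, num_inputs)
scores = torch.randn(num_frames, num_inputs)
weights = scores.softmax(dim=1)                           # (num_frames, num_inputs)
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))  # (num_frames, num_channels, 1)
print(ans.squeeze(2).shape)                               # torch.Size([4, 8])
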

Changed file 2 of 2:

@@ -545,6 +545,24 @@ class ActivationBalancer(torch.nn.Module):
return _no_op(x)
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+    """
+    Returns x unmodified, but in backprop will add a penalty for the excess of
+    the absolute values of elements of x over the limit `limit`.  E.g. if
+    limit == 10.0, then any element of x whose absolute value exceeds 10.0
+    incurs a penalty.
+    Caution: the value of this penalty will be affected by the grad scaling
+    used in automatic mixed precision training.  Because we only use this as
+    an emergency brake to keep really implausible score values from reaching
+    the softmax, that shouldn't really matter, and may even be helpful.
+    """
+    aux_loss = penalty * (x.abs() - limit).relu()
+    # note: we don't do sum() here on aux_loss, but it's as if we had done
+    # sum() due to how with_loss() works.
+    x = with_loss(x, aux_loss)
+    # you must use x for something, or this will be ineffective.
+    return x
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
if x.ndim == 2:
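
For reference, the extra gradient that penalize_abs_values_gt contributes can be checked directly. Assuming with_loss(x, aux_loss) acts as if sum(aux_loss) were added to the training loss (which is what the in-code note says), the backward pass picks up penalty * sign(x) wherever |x| exceeds the limit, and nothing elsewhere:

import torch

limit, penalty = 25.0, 1.0e-04
x = torch.tensor([0.5, -30.0, 40.0], requires_grad=True)
aux_loss = penalty * (x.abs() - limit).relu()
aux_loss.sum().backward()
# Zero gradient for the in-range element, +/-penalty for the two outliers.
print(x.grad)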