Merge branch 'scaled_adam_exp168' into scaled_adam_exp169

Daniel Povey 2022-10-22 14:05:07 +08:00
commit 9672dffac2
2 changed files with 47 additions and 18 deletions


@@ -35,7 +35,7 @@ from scaling import (
Identity,
_diag,
random_clamp,
with_loss,
penalize_abs_values_gt,
)
from torch import Tensor, nn
@@ -647,6 +647,11 @@ class AttentionDownsample(torch.nn.Module):
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
scores = (src * self.query).sum(dim=-1, keepdim=True)
scores = penalize_abs_values_gt(scores,
limit=5.0,
penalty=1.0e-04)
weights = scores.softmax(dim=1)
# ans1 is the first `in_channels` channels of the output
@@ -717,7 +722,13 @@ class SimpleCombiner(torch.nn.Module):
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
weight = (weight1 + weight2).sigmoid()
logit = (weight1 + weight2)
if self.training and random.random() < 0.1:
logit = penalize_abs_values_gt(logit,
limit=25.0,
penalty=1.0e-04)
weight = logit.sigmoid()
src2_part1 = src2[...,:dim1]
part1 = src1 * weight + src2_part1 * (1.0 - weight)
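The diff does not say why the combined logit is only softly limited before the sigmoid; a plausible reading (an assumption here, in the spirit of the "emergency brake" comment in the next hunk) is that it guards against logits so large that the sigmoid is numerically saturated and passes essentially no gradient. A standalone check in plain PyTorch:

import torch

logit = torch.tensor([0.0, 5.0, 25.0, 50.0])
print(torch.sigmoid(logit))
# tensor([0.5000, 0.9933, 1.0000, 1.0000])
# beyond |logit| of about 25 the sigmoid output is within ~1e-11 of 1.0,
# so its gradient is effectively zero and the weight can no longer move.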
@@ -1116,15 +1127,9 @@ class RelPositionMultiheadAttention(nn.Module):
# this mechanism instead of, say, a limit on entropy, because once the entropy
# gets very small, gradients through the softmax can become very small, and
# mechanisms like that become ineffective.
attn_weights_limit = 50.0
# caution: this penalty will be affected by grad-scaling in amp.
# It's OK; this is just an emergency brake, and under normal
# conditions it shouldn't be active
attn_weights_penalty = 1.0e-04
aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
attn_weights_limit).relu()
attn_output_weights = with_loss(attn_output_weights,
aux_loss)
attn_output_weights = penalize_abs_values_gt(attn_output_weights,
limit=25.0,
penalty=1.0e-04)
# attn_output_weights: (batch, head, time1, time2)
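As a standalone illustration of what this emergency brake does (not part of the diff, and assuming with_loss() in scaling.py behaves as described by the comments below, i.e. it adds aux_loss.sum() to the loss in backprop), the penalty contributes gradient only for elements whose absolute value exceeds the limit, while the forward value is untouched:

import torch
from scaling import penalize_abs_values_gt  # imported as in the first hunk above

scores = torch.tensor([1.0, 30.0, -40.0], requires_grad=True)
out = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04)
assert torch.equal(out, scores)   # forward pass is unchanged
# `out` must feed into the loss, otherwise backward() never reaches the penalty.
(out * 0.0).sum().backward()
print(scores.grad)   # [0.0, 1e-04, -1e-04]: nonzero only where |score| > 25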
@@ -1545,22 +1550,22 @@ class AttentionCombine(nn.Module):
(num_frames, num_channels, num_inputs)
)
weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
if random.random() < 0.002:
logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}")
if self.training:
# random masking..
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
size=(num_frames,), device=weights.device).unsqueeze(1)
size=(num_frames,), device=scores.device).unsqueeze(1)
# mask will have rows like: [ False, False, False, True, True, .. ]
arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand(
num_frames, num_inputs)
mask = arange >= mask_start
apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
device=weights.device) < self.single_prob,
device=scores.device) < self.single_prob,
mask_start < num_inputs)
single_prob_mask = torch.logical_and(apply_single_prob,
arange < mask_start - 1)
@@ -1568,8 +1573,14 @@ class AttentionCombine(nn.Module):
mask = torch.logical_or(mask,
single_prob_mask)
weights = weights.masked_fill(mask, float('-inf'))
weights = weights.softmax(dim=1)
scores = scores.masked_fill(mask, float('-inf'))
if self.training and random.random() < 0.1:
scores = penalize_abs_values_gt(scores,
limit=10.0,
penalty=1.0e-04)
weights = scores.softmax(dim=1)
# (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
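To make the random-masking step easier to follow, here is a small standalone illustration (the mask_start values below are invented for the example, not taken from the code): it shows how `arange >= mask_start` yields rows of the form described in the comment above, and how the -inf fill removes the masked inputs from the softmax:

import torch

num_frames, num_inputs = 4, 5
mask_start = torch.tensor([[2], [4], [1], [7]])   # hypothetical per-frame start indices
arange = torch.arange(num_inputs).unsqueeze(0).expand(num_frames, num_inputs)
mask = arange >= mask_start
# mask is now:
# [[False, False,  True,  True,  True],
#  [False, False, False, False,  True],
#  [False,  True,  True,  True,  True],
#  [False, False, False, False, False]]   (a start index >= num_inputs masks nothing)
scores = torch.zeros(num_frames, num_inputs).masked_fill(mask, float('-inf'))
print(scores.softmax(dim=1))   # masked inputs get exactly zero weight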


@@ -545,6 +545,24 @@ class ActivationBalancer(torch.nn.Module):
return _no_op(x)
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
"""
Returns x unmodified, but in backprop will put a penalty for the excess of
the absolute values of elements of x over the limit "limit". E.g. if
limit == 10.0, then if x has any values over 10 it will get a penalty.
Caution: the value of this penalty will be affected by grad scaling used
in automatic mixed precision training. For this reasons we use this,
it shouldn't really matter, or may even be helpful; we just use this
to disallow really implausible values of scores to be given to softmax.
"""
aux_loss = penalty * (x.abs() - limit).relu()
# note: we don't do sum() here on aux_loss, but it's as if we had done
# sum() due to how with_loss() works.
x = with_loss(x, aux_loss)
# you must use x for something, or this will be ineffective.
return x
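with_loss() itself is not shown in this diff; for readers following the comment above ("as if we had done sum()"), a minimal sketch of a custom autograd function with that behaviour could look like the following. This is an assumption about with_loss(), not the actual scaling.py implementation:

import torch
from torch import Tensor

class WithLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor) -> Tensor:
        # forward just returns x; remember y's shape for the backward pass.
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # pass the gradient for x straight through, and give y the gradient
        # it would receive if y.sum() had been added to the loss, i.e. all ones.
        return ans_grad, torch.ones(ctx.y_shape, dtype=ans_grad.dtype,
                                    device=ans_grad.device)

def with_loss(x: Tensor, y: Tensor) -> Tensor:
    # returns x, but in backprop it is as if y.sum() had been added to the loss.
    return WithLoss.apply(x, y)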
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
if x.ndim == 2: