Merge branch 'scaled_adam_exp663' into scaled_adam_exp665

Daniel Povey 2022-12-10 00:07:37 +08:00
commit 0fc646f281
2 changed files with 27 additions and 12 deletions

View File

@@ -711,7 +711,8 @@ class ActivationBalancer(torch.nn.Module):
         return _no_op(x)
 
 
-def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
+                           name: str = None) -> Tensor:
     """
     Returns x unmodified, but in backprop will put a penalty for the excess of
     the absolute values of elements of x over the limit "limit". E.g. if
@@ -721,6 +722,8 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     in automatic mixed precision training.  For this reasons we use this,
     it shouldn't really matter, or may even be helpful; we just use this
     to disallow really implausible values of scores to be given to softmax.
+
+    The name is for randomly printed debug info.
     """
     x_sign = x.sign()
     over_limit = (x.abs() - limit) > 0
@@ -734,7 +737,7 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
-    x = with_loss(x, aux_loss)
+    x = with_loss(x, aux_loss, name)
     # you must use x for something, or this will be ineffective.
     return x
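
A quick way to see what the hunk above does (this check is not part of the commit): for every element whose absolute value exceeds limit, the backward pass adds an extra +/- penalty to that element's gradient, on top of whatever the main loss contributes. A minimal sketch, assuming the updated function can be imported from the modified scaling module (the import path is an assumption):

    import torch
    from scaling import penalize_abs_values_gt  # import path is an assumption

    x = torch.tensor([0.5, -30.0, 12.0], requires_grad=True)
    y = penalize_abs_values_gt(x, limit=10.0, penalty=1.0e-04, name="demo")
    # x (via y) must actually feed into the loss, otherwise the attached penalty has no effect.
    y.sum().backward()
    print(x.grad)  # approx tensor([1.0000, 0.9999, 1.0001]): +/- 1e-04 only on the out-of-range elements
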
@@ -907,17 +910,23 @@ class Whiten(nn.Module):
 class WithLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, y: Tensor):
-        ctx.y_shape = y.shape
+    def forward(ctx, x: Tensor, y: Tensor, name: str):
+        ctx.name = name
+        ctx.save_for_backward(y)  # just for printing the name, and the shape
         return x
+
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
-        return ans_grad, torch.ones(ctx.y_shape,
+        y, = ctx.saved_tensors
+        if random.random() < 0.002 and ctx.name is not None:
+            loss_sum = y.sum().item()
+            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
+        return ans_grad, torch.ones(y.shape,
                                     dtype=ans_grad.dtype,
-                                    device=ans_grad.device)
+                                    device=ans_grad.device), None
 
-def with_loss(x, y):
+def with_loss(x, y, name):
     # returns x but adds y.sum() to the loss function.
-    return WithLoss.apply(x, y)
+    return WithLoss.apply(x, y, name)
 
 
 class ScaleGradFunction(torch.autograd.Function):
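
For reference (not part of the diff): with_loss works by having WithLoss.backward return a tensor of ones as the "gradient" for y, so each element of y behaves exactly as if y.sum() had been added to the training loss, without ever materializing that sum in the forward pass. The new third return value, None, is needed because autograd expects one gradient per forward input and name now counts as an input. A minimal sketch, assuming with_loss can be imported from the modified scaling module (the import path is an assumption):

    import torch
    from scaling import with_loss  # import path is an assumption

    x = torch.randn(4, requires_grad=True)
    aux = 0.01 * x ** 2                 # an auxiliary penalty that depends on x
    out = with_loss(x, aux, name=None)  # name=None disables the occasional debug log
    out.sum().backward()
    # x.grad is the main-path gradient (1.0) plus d(aux.sum())/dx = 0.02 * x:
    print(torch.allclose(x.grad, 1.0 + 0.02 * x.detach()))  # True
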

View File

@@ -801,6 +801,8 @@ class AttentionDownsample(torch.nn.Module):
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
 
+        self.name = None  # will be set from training code
+
         # fill in the extra dimensions with a projection of the input
         if out_channels > in_channels:
             self.extra_proj = nn.Linear(in_channels * downsample,
@@ -833,8 +835,9 @@ class AttentionDownsample(torch.nn.Module):
         scores = (src * self.query).sum(dim=-1, keepdim=True)
 
         scores = penalize_abs_values_gt(scores,
-                                        limit=10.0,
-                                        penalty=1.0e-04)
+                                        limit=20.0,
+                                        penalty=1.0e-04,
+                                        name=self.name)
 
         weights = scores.softmax(dim=1)
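
The "self.name = None  # will be set from training code" lines added above (and in AttentionCombine further down) only become useful once the training script fills them in; this commit does not show that side. One plausible way to do it, given here purely as an assumption rather than as the training code actually used, is to walk the model once and assign each submodule its dotted path:

    import torch.nn as nn

    def set_module_names(model: nn.Module) -> None:
        # Hypothetical helper (not shown in this commit): any submodule that
        # declares a `name` attribute gets its dotted path, so the occasional
        # "WithLoss: name=..., loss-sum=..." log lines are attributable.
        for name, module in model.named_modules():
            if hasattr(module, "name"):
                module.name = name
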
@@ -1207,7 +1210,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # under normal circumstances.
             attn_scores = penalize_abs_values_gt(attn_scores,
                                                  limit=25.0,
-                                                 penalty=1.0e-04)
+                                                 penalty=1.0e-04,
+                                                 name=self.name)
 
         assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
@@ -1870,6 +1874,7 @@ class AttentionCombine(nn.Module):
                                                      num_inputs))
         self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
+        self.name = None  # will be set from training code
 
         assert 0 <= random_prob <= 1, random_prob
         assert 0 <= single_prob <= 1, single_prob
@@ -1926,7 +1931,8 @@ class AttentionCombine(nn.Module):
         if self.training and random.random() < 0.1:
             scores = penalize_abs_values_gt(scores,
                                             limit=10.0,
-                                            penalty=1.0e-04)
+                                            penalty=1.0e-04,
+                                            name=self.name)
 
         weights = scores.softmax(dim=1)
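
A note on how often the new debug line will actually appear: WithLoss.backward logs with probability 0.002 per call, and in AttentionCombine the penalize_abs_values_gt call is itself behind `if self.training and random.random() < 0.1`, so that module's name shows up on roughly 0.1 * 0.002 = 0.0002 of training batches. When it does, the line follows the f-string in the diff above; the name and value shown here are placeholders, not real output:

    WithLoss: name=<whatever the training code assigned to self.name>, loss-sum=<y.sum(), formatted as .3e, e.g. 0.000e+00 when nothing exceeds the limit>
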