Merge branch 'scaled_adam_exp663' into scaled_adam_exp665

Daniel Povey 2022-12-10 00:07:37 +08:00
commit 0fc646f281
2 changed files with 27 additions and 12 deletions

View File

@@ -711,7 +711,8 @@ class ActivationBalancer(torch.nn.Module):
         return _no_op(x)


-def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
+                           name: str = None) -> Tensor:
     """
     Returns x unmodified, but in backprop will put a penalty for the excess of
     the absolute values of elements of x over the limit "limit".  E.g. if
@@ -721,6 +722,8 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     in automatic mixed precision training.  For this reason we use this,
     it shouldn't really matter, or may even be helpful; we just use this
     to disallow really implausible values of scores to be given to softmax.
+
+    The name is for randomly printed debug info.
     """
     x_sign = x.sign()
     over_limit = (x.abs() - limit) > 0
@@ -734,7 +737,7 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux_loss, but it's as if we had done
     # sum() due to how with_loss() works.
-    x = with_loss(x, aux_loss)
+    x = with_loss(x, aux_loss, name)
     # you must use x for something, or this will be ineffective.
     return x
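
Note: the following is a small, self-contained sketch (not part of this commit) of the penalty that penalize_abs_values_gt attaches in backprop, written out with the same formula as the hunk above and illustrative values limit=10.0, penalty=1.0e-04. Only elements with |x| > limit receive an extra gradient of sign(x) * penalty; everything else is untouched.

    import torch

    x = torch.tensor([3.0, 15.0, -40.0], requires_grad=True)
    limit, penalty = 10.0, 1.0e-04

    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    aux_loss.sum().backward()   # with_loss() has the same effect without the explicit sum()
    print(x.grad)               # tensor([ 0.0000e+00,  1.0000e-04, -1.0000e-04])
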
@@ -907,17 +910,23 @@ class Whiten(nn.Module):
 class WithLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, y: Tensor):
-        ctx.y_shape = y.shape
+    def forward(ctx, x: Tensor, y: Tensor, name: str):
+        ctx.name = name
+        ctx.save_for_backward(y)  # just for printing the name, and the shape
         return x

     @staticmethod
     def backward(ctx, ans_grad: Tensor):
-        return ans_grad, torch.ones(ctx.y_shape,
+        y, = ctx.saved_tensors
+        if random.random() < 0.002 and ctx.name is not None:
+            loss_sum = y.sum().item()
+            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
+        return ans_grad, torch.ones(y.shape,
                                     dtype=ans_grad.dtype,
-                                    device=ans_grad.device)
+                                    device=ans_grad.device), None


-def with_loss(x, y):
+def with_loss(x, y, name):
     # returns x but adds y.sum() to the loss function.
-    return WithLoss.apply(x, y)
+    return WithLoss.apply(x, y, name)


 class ScaleGradFunction(torch.autograd.Function):
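
Note: below is a minimal, self-contained sketch of the mechanism these hunks modify, assuming only standard torch/random/logging imports (it is not the repository file itself). with_loss() returns x unchanged in the forward pass; in backward it passes the incoming gradient through to x and feeds a gradient of ones into y, so y behaves as if y.sum() had been added to the loss. The new name argument is used only for the occasional debug log line.

    import logging
    import random

    import torch
    from torch import Tensor


    class WithLoss(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, y: Tensor, name: str):
            ctx.name = name
            ctx.save_for_backward(y)  # kept only so backward can log its sum
            return x

        @staticmethod
        def backward(ctx, ans_grad: Tensor):
            y, = ctx.saved_tensors
            if random.random() < 0.002 and ctx.name is not None:
                logging.info(f"WithLoss: name={ctx.name}, loss-sum={y.sum().item():.3e}")
            # grad for x passes through unchanged; grad for y is all ones
            # (equivalent to adding y.sum() to the loss); None for `name`.
            return ans_grad, torch.ones(y.shape,
                                        dtype=ans_grad.dtype,
                                        device=ans_grad.device), None


    def with_loss(x: Tensor, y: Tensor, name: str) -> Tensor:
        # returns x but adds y.sum() to the loss function.
        return WithLoss.apply(x, y, name)


    if __name__ == "__main__":
        x = torch.tensor([1.0, 30.0], requires_grad=True)
        aux = 1.0e-04 * x.abs()              # stand-in for the penalty term
        out = with_loss(x, aux, name="demo")
        out.sum().backward()
        print(x.grad)                        # tensor([1.0001, 1.0001]): pass-through grad plus the aux gradient
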

View File

@@ -801,6 +801,8 @@ class AttentionDownsample(torch.nn.Module):
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+        self.name = None  # will be set from training code

         # fill in the extra dimensions with a projection of the input
         if out_channels > in_channels:
             self.extra_proj = nn.Linear(in_channels * downsample,
@@ -833,8 +835,9 @@ class AttentionDownsample(torch.nn.Module):
         scores = (src * self.query).sum(dim=-1, keepdim=True)

         scores = penalize_abs_values_gt(scores,
-                                        limit=10.0,
-                                        penalty=1.0e-04)
+                                        limit=20.0,
+                                        penalty=1.0e-04,
+                                        name=self.name)

         weights = scores.softmax(dim=1)
@@ -1207,7 +1210,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # under normal circumstances.
             attn_scores = penalize_abs_values_gt(attn_scores,
                                                  limit=25.0,
-                                                 penalty=1.0e-04)
+                                                 penalty=1.0e-04,
+                                                 name=self.name)

         assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
@ -1870,6 +1874,7 @@ class AttentionCombine(nn.Module):
num_inputs)) num_inputs))
self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
self.name = None # will be set from training code
assert 0 <= random_prob <= 1, random_prob assert 0 <= random_prob <= 1, random_prob
assert 0 <= single_prob <= 1, single_prob assert 0 <= single_prob <= 1, single_prob
@@ -1926,7 +1931,8 @@ class AttentionCombine(nn.Module):
         if self.training and random.random() < 0.1:
             scores = penalize_abs_values_gt(scores,
                                             limit=10.0,
-                                            penalty=1.0e-04)
+                                            penalty=1.0e-04,
+                                            name=self.name)

         weights = scores.softmax(dim=1)
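
Note: the `self.name = None  # will be set from training code` lines above assume the training script assigns a human-readable name to each of these modules; that script is not part of this diff. The loop below is a hypothetical illustration (the helper set_module_names is not from the repository) of how such names could be filled in so the occasional WithLoss log lines identify the module whose scores triggered the penalty.

    import torch.nn as nn


    def set_module_names(model: nn.Module) -> None:
        # Attach each submodule's fully qualified name within the model
        # (e.g. "encoder.layers.2.self_attn") to modules that expose a
        # `name` attribute, so debug logging can identify the source.
        for fq_name, module in model.named_modules():
            if hasattr(module, "name"):
                module.name = fq_name
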