mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Merge branch 'scaled_adam_exp663' into scaled_adam_exp665
commit 0fc646f281
@@ -711,7 +711,8 @@ class ActivationBalancer(torch.nn.Module):
             return _no_op(x)


-def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
+                           name: str = None) -> Tensor:
     """
     Returns x unmodified, but in backprop will put a penalty for the excess of
     the absolute values of elements of x over the limit "limit". E.g. if
@@ -721,6 +722,8 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     in automatic mixed precision training. For this reasons we use this,
     it shouldn't really matter, or may even be helpful; we just use this
     to disallow really implausible values of scores to be given to softmax.
+
+    The name is for randomly printed debug info.
     """
     x_sign = x.sign()
     over_limit = (x.abs() - limit) > 0
@@ -734,7 +737,7 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux)_loss, but it's as if we had done
     # sum() due to how with_loss() works.
-    x = with_loss(x, aux_loss)
+    x = with_loss(x, aux_loss, name)
     # you must use x for something, or this will be ineffective.
     return x

@@ -907,17 +910,23 @@ class Whiten(nn.Module):

 class WithLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, y: Tensor):
-        ctx.y_shape = y.shape
+    def forward(ctx, x: Tensor, y: Tensor, name: str):
+        ctx.name = name
+        ctx.save_for_backward(y)  # just for printing the name, and the shape
         return x
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
-        return ans_grad, torch.ones(ctx.y_shape,
+        y, = ctx.saved_tensors
+        if random.random() < 0.002 and ctx.name is not None:
+            loss_sum = y.sum().item()
+            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
+
+        return ans_grad, torch.ones(y.shape,
                                     dtype=ans_grad.dtype,
-                                    device=ans_grad.device)
-def with_loss(x, y):
+                                    device=ans_grad.device), None
+def with_loss(x, y, name):
     # returns x but adds y.sum() to the loss function.
-    return WithLoss.apply(x, y)
+    return WithLoss.apply(x, y, name)


 class ScaleGradFunction(torch.autograd.Function):
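For reference, the trick that with_loss() / WithLoss implement is to return x numerically unchanged while making autograd behave as if y.sum() had been added to the training loss. A minimal self-contained sketch of that mechanism (illustrative only; the AttachLoss name and the example penalty are not from icefall):

import torch
from torch import Tensor

class AttachLoss(torch.autograd.Function):
    """Returns x unchanged, but backprops as if aux.sum() were added to the loss."""

    @staticmethod
    def forward(ctx, x: Tensor, aux: Tensor) -> Tensor:
        ctx.aux_shape = aux.shape
        return x  # forward value is untouched

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # grad w.r.t. x passes straight through; grad w.r.t. aux is all-ones,
        # i.e. exactly the gradient of aux.sum().
        return ans_grad, torch.ones(ctx.aux_shape,
                                    dtype=ans_grad.dtype,
                                    device=ans_grad.device)

x = torch.randn(4, requires_grad=True)
aux = 1.0e-04 * (x.abs() - 1.0).clamp(min=0.0)  # some penalty on |x| > 1
y = AttachLoss.apply(x, aux)                    # y == x numerically
(y ** 2).sum().backward()
# x.grad now includes both d/dx of (x**2).sum() and d/dx of aux.sum()

Saving y via save_for_backward() (as the new WithLoss code does), rather than only its shape, is what makes the occasional loss-sum debug print in backward() possible.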
@@ -801,6 +801,8 @@ class AttentionDownsample(torch.nn.Module):
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))

+        self.name = None  # will be set from training code
+
         # fill in the extra dimensions with a projection of the input
         if out_channels > in_channels:
             self.extra_proj = nn.Linear(in_channels * downsample,
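The modules only initialize self.name to None; per the comment, the training code is expected to fill it in. A hypothetical sketch of how a training script could do that (the helper name is mine; `model` is assumed to be the nn.Module containing these submodules):

import torch

def set_module_names(model: torch.nn.Module) -> None:
    # Give every submodule that has a `name` attribute its fully-qualified
    # module path, so the debug prints in WithLoss.backward are identifiable.
    for name, module in model.named_modules():
        if hasattr(module, "name"):
            module.name = name  # e.g. "encoder.layers.3.self_attn" (illustrative)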
@@ -833,8 +835,9 @@ class AttentionDownsample(torch.nn.Module):
         scores = (src * self.query).sum(dim=-1, keepdim=True)

         scores = penalize_abs_values_gt(scores,
-                                        limit=10.0,
-                                        penalty=1.0e-04)
+                                        limit=20.0,
+                                        penalty=1.0e-04,
+                                        name=self.name)

         weights = scores.softmax(dim=1)

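To make the limit/penalty semantics at this call site concrete: with limit=20.0 and penalty=1.0e-04, scores inside [-20, 20] are not penalized at all, while each score outside that range receives a constant extra gradient of magnitude 1e-4 pushing it back toward the limit. A standalone illustration of the same aux_loss formula used in penalize_abs_values_gt (toy values; in the real code the loss is attached via with_loss() rather than backwarded directly):

import torch

scores = torch.tensor([5.0, 30.0, -40.0], requires_grad=True)
limit, penalty = 20.0, 1.0e-04

x_sign = scores.sign()
over_limit = (scores.abs() - limit) > 0
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * scores)
aux_loss.sum().backward()
print(scores.grad)  # 0 for the in-range score, +1e-4 and -1e-4 for the others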
@@ -1207,7 +1210,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # under normal circumstances.
             attn_scores = penalize_abs_values_gt(attn_scores,
                                                  limit=25.0,
-                                                 penalty=1.0e-04)
+                                                 penalty=1.0e-04,
+                                                 name=self.name)

         assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)

@@ -1870,6 +1874,7 @@ class AttentionCombine(nn.Module):
                                               num_inputs))
         self.bias = torch.nn.Parameter(torch.zeros(num_inputs))

+        self.name = None  # will be set from training code
         assert 0 <= random_prob <= 1, random_prob
         assert 0 <= single_prob <= 1, single_prob

@@ -1926,7 +1931,8 @@ class AttentionCombine(nn.Module):
         if self.training and random.random() < 0.1:
             scores = penalize_abs_values_gt(scores,
                                             limit=10.0,
-                                            penalty=1.0e-04)
+                                            penalty=1.0e-04,
+                                            name=self.name)

         weights = scores.softmax(dim=1)
