mirror of https://github.com/k2-fsa/icefall.git
Merge branch 'scaled_adam_exp663' into scaled_adam_exp665
commit 0fc646f281
@@ -711,7 +711,8 @@ class ActivationBalancer(torch.nn.Module):
         return _no_op(x)


-def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
+                           name: str = None) -> Tensor:
     """
     Returns x unmodified, but in backprop will put a penalty for the excess of
     the absolute values of elements of x over the limit "limit". E.g. if
@@ -721,6 +722,8 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     in automatic mixed precision training. For this reasons we use this,
     it shouldn't really matter, or may even be helpful; we just use this
     to disallow really implausible values of scores to be given to softmax.
+
+    The name is for randomly printed debug info.
     """
     x_sign = x.sign()
     over_limit = (x.abs() - limit) > 0
@@ -734,7 +737,7 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux)_loss, but it's as if we had done
     # sum() due to how with_loss() works.
-    x = with_loss(x, aux_loss)
+    x = with_loss(x, aux_loss, name)
     # you must use x for something, or this will be ineffective.
     return x

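The comment above says the auxiliary loss acts as if sum() had been applied, because of how with_loss() works. A minimal standalone check of the resulting gradient (toy values, calling backward() on the sum directly instead of going through with_loss(); not part of this commit):

    import torch

    x = torch.tensor([0.5, 30.0, -40.0], requires_grad=True)
    limit, penalty = 25.0, 1.0e-04
    # Same expression as aux_loss above: non-zero only where |x| > limit.
    aux_loss = penalty * ((x.sign() * ((x.abs() - limit) > 0)).to(torch.int8) * x)
    aux_loss.sum().backward()
    # Elements over the limit pick up an extra gradient of +/- penalty:
    print(x.grad)  # tensor([ 0.0000,  0.0001, -0.0001])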
@@ -907,17 +910,23 @@ class Whiten(nn.Module):

 class WithLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, y: Tensor):
-        ctx.y_shape = y.shape
+    def forward(ctx, x: Tensor, y: Tensor, name: str):
+        ctx.name = name
+        ctx.save_for_backward(y) # just for printing the name, and the shape
         return x
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
-        return ans_grad, torch.ones(ctx.y_shape,
+        y, = ctx.saved_tensors
+        if random.random() < 0.002 and ctx.name is not None:
+            loss_sum = y.sum().item()
+            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
+
+        return ans_grad, torch.ones(y.shape,
                                     dtype=ans_grad.dtype,
-                                    device=ans_grad.device)
-def with_loss(x, y):
+                                    device=ans_grad.device), None
+def with_loss(x, y, name):
     # returns x but adds y.sum() to the loss function.
-    return WithLoss.apply(x, y)
+    return WithLoss.apply(x, y, name)


 class ScaleGradFunction(torch.autograd.Function):
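For reference, a self-contained sketch of how the updated pair behaves (the underscore-prefixed names and the demo usage are illustrative, not imports from icefall): the forward pass returns x untouched, the backward pass hands y a gradient of all ones so that y.sum() is effectively added to the loss, returns None for the non-tensor name argument, and logs the name and loss sum on roughly one backward call in 500.

    import logging
    import random

    import torch
    from torch import Tensor


    class _WithLoss(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, y: Tensor, name: str):
            ctx.name = name
            ctx.save_for_backward(y)  # kept only for the debug print and y's shape
            return x

        @staticmethod
        def backward(ctx, ans_grad: Tensor):
            y, = ctx.saved_tensors
            if random.random() < 0.002 and ctx.name is not None:
                logging.info(f"WithLoss: name={ctx.name}, loss-sum={y.sum().item():.3e}")
            # An all-ones gradient for y is equivalent to adding y.sum() to the loss;
            # the trailing None corresponds to the non-tensor `name` argument.
            return ans_grad, torch.ones(y.shape,
                                        dtype=ans_grad.dtype,
                                        device=ans_grad.device), None


    def _with_loss(x: Tensor, y: Tensor, name: str) -> Tensor:
        return _WithLoss.apply(x, y, name)


    # x flows through unchanged; its gradient picks up d(y)/d(x) on top of the
    # gradient coming from whatever the output is used for.
    x = torch.randn(4, requires_grad=True)
    out = _with_loss(x, 0.01 * x ** 2, name="demo")
    out.sum().backward()
    print(x.grad)  # equals 1 + 0.02 * x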
@@ -801,6 +801,8 @@ class AttentionDownsample(torch.nn.Module):
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))

+        self.name = None # will be set from training code
+
         # fill in the extra dimensions with a projection of the input
         if out_channels > in_channels:
             self.extra_proj = nn.Linear(in_channels * downsample,
@@ -833,8 +835,9 @@ class AttentionDownsample(torch.nn.Module):
         scores = (src * self.query).sum(dim=-1, keepdim=True)

         scores = penalize_abs_values_gt(scores,
-                                        limit=10.0,
-                                        penalty=1.0e-04)
+                                        limit=20.0,
+                                        penalty=1.0e-04,
+                                        name=self.name)

         weights = scores.softmax(dim=1)

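Both AttentionDownsample here and AttentionCombine below initialize self.name = None with the comment "will be set from training code"; that wiring is not part of this diff. One plausible way to do it (an assumption, including the helper name set_module_names) is to assign fully-qualified module paths once the model is built:

    import torch.nn as nn

    def set_module_names(model: nn.Module) -> None:
        # Hypothetical helper: give every submodule that exposes a `name`
        # attribute its module path, so the WithLoss debug lines identify
        # which call site produced the auxiliary loss.
        for name, module in model.named_modules():
            if hasattr(module, "name"):
                module.name = name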
@@ -1207,7 +1210,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # under normal circumstances.
             attn_scores = penalize_abs_values_gt(attn_scores,
                                                  limit=25.0,
-                                                 penalty=1.0e-04)
+                                                 penalty=1.0e-04,
+                                                 name=self.name)

             assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)

@@ -1870,6 +1874,7 @@ class AttentionCombine(nn.Module):
                                                  num_inputs))
         self.bias = torch.nn.Parameter(torch.zeros(num_inputs))

+        self.name = None # will be set from training code
         assert 0 <= random_prob <= 1, random_prob
         assert 0 <= single_prob <= 1, single_prob

@@ -1926,7 +1931,8 @@ class AttentionCombine(nn.Module):
         if self.training and random.random() < 0.1:
             scores = penalize_abs_values_gt(scores,
                                             limit=10.0,
-                                            penalty=1.0e-04)
+                                            penalty=1.0e-04,
+                                            name=self.name)

         weights = scores.softmax(dim=1)

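A note on how often the debug lines appear: WithLoss.backward() logs with probability 0.002, i.e. about once per 500 backward passes through a given call site. At this last call site the penalty is additionally gated on self.training and random.random() < 0.1, so a line from AttentionCombine shows up on average once per 1 / (0.1 × 0.002) = 5000 training batches. The lines go through logging.info(), so INFO-level logging must be enabled in the training script for them to appear.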