Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Merge branch 'scaled_adam_exp168' into scaled_adam_exp169
commit 9672dffac2
@@ -35,7 +35,7 @@ from scaling import (
     Identity,
     _diag,
     random_clamp,
-    with_loss,
+    penalize_abs_values_gt,
 )
 from torch import Tensor, nn
 
@@ -647,6 +647,11 @@ class AttentionDownsample(torch.nn.Module):
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
         scores = (src * self.query).sum(dim=-1, keepdim=True)
+
+        scores = penalize_abs_values_gt(scores,
+                                        limit=5.0,
+                                        penalty=1.0e-04)
+
         weights = scores.softmax(dim=1)
 
         # ans1 is the first `in_channels` channels of the output
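Note that penalize_abs_values_gt leaves the forward value of `scores` untouched; the term it contributes to the loss is exactly zero as long as every score stays within [-limit, limit], so under normal conditions the call above is inert and only kicks in as an emergency brake. A minimal self-contained check of the penalty term, using toy scores (not from the commit) and the limit=5.0, penalty=1.0e-04 from this hunk:

import torch

# Toy scores; limit and penalty match the AttentionDownsample call above.
scores = torch.tensor([-6.0, -3.0, 0.0, 4.9, 6.0])
penalty_term = 1.0e-04 * (scores.abs() - 5.0).relu()
print(penalty_term)
# Only the out-of-range scores (-6.0 and 6.0) contribute, roughly 1e-4 each;
# every score with |score| <= 5.0 contributes exactly 0.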
@@ -717,7 +722,13 @@ class SimpleCombiner(torch.nn.Module):
 
         weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
         weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
-        weight = (weight1 + weight2).sigmoid()
+        logit = (weight1 + weight2)
+
+        if self.training and random.random() < 0.1:
+            logit = penalize_abs_values_gt(logit,
+                                           limit=25.0,
+                                           penalty=1.0e-04)
+        weight = logit.sigmoid()
 
         src2_part1 = src2[...,:dim1]
         part1 = src1 * weight + src2_part1 * (1.0 - weight)
@@ -1116,15 +1127,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # this mechanism instead of, say, a limit on entropy, because once the entropy
         # gets very small gradients through the softmax can become very small, and
         # some mechanisms like that become ineffective.
-        attn_weights_limit = 50.0
-        # caution: this penalty will be affected by grad-scaling in amp.
-        # It's OK; this is just an emergency brake, and under normal
-        # conditions it shouldn't be active
-        attn_weights_penalty = 1.0e-04
-        aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
-                                           attn_weights_limit).relu()
-        attn_output_weights = with_loss(attn_output_weights,
-                                        aux_loss)
+        attn_output_weights = penalize_abs_values_gt(attn_output_weights,
+                                                     limit=25.0,
+                                                     penalty=1.0e-04)
 
 
         # attn_output_weights: (batch, head, time1, time2)
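The comment retained above explains the choice of mechanism: once one pre-softmax score dominates, the softmax saturates and the gradients flowing through it become vanishingly small, so an entropy-based constraint would have nothing left to act on, whereas penalizing |scores| directly still yields a usable gradient. A small standalone demonstration of that saturation, with toy numbers rather than real attention scores:

import torch

for big in (5.0, 50.0):
    scores = torch.tensor([0.0, 0.0, big], requires_grad=True)
    probs = scores.softmax(dim=0)
    probs[0].backward()
    print(f"big={big}: probs={probs.detach()}, grad={scores.grad}")
# With big=50 the softmax is essentially one-hot and scores.grad is on the
# order of 1e-22, so only a penalty applied directly to the scores can still
# pull them back into a sensible range.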
@@ -1545,22 +1550,22 @@ class AttentionCombine(nn.Module):
             (num_frames, num_channels, num_inputs)
         )
 
-        weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
+        scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
 
         if random.random() < 0.002:
-            logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
+            logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}")
 
         if self.training:
             # random masking..
             mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
-                                       size=(num_frames,), device=weights.device).unsqueeze(1)
+                                       size=(num_frames,), device=scores.device).unsqueeze(1)
             # mask will have rows like: [ False, False, False, True, True, .. ]
-            arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+            arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand(
                 num_frames, num_inputs)
             mask = arange >= mask_start
 
             apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
-                                                             device=weights.device) < self.single_prob,
+                                                             device=scores.device) < self.single_prob,
                                                   mask_start < num_inputs)
             single_prob_mask = torch.logical_and(apply_single_prob,
                                                  arange < mask_start - 1)
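For reference, the masking idiom above produces rows of the form [False, ..., False, True, ..., True], dropping each frame's later inputs before the softmax over scores. A toy-sized illustration (the shapes and mask_start values here are made up, not taken from the model):

import torch

num_frames, num_inputs = 3, 4
mask_start = torch.tensor([[2], [1], [4]])   # hypothetical per-frame values
arange = torch.arange(num_inputs).unsqueeze(0).expand(num_frames, num_inputs)
mask = arange >= mask_start
print(mask)
# tensor([[False, False,  True,  True],
#         [False,  True,  True,  True],
#         [False, False, False, False]])
# True positions are later filled with -inf so those inputs get zero weight
# after the softmax; a row with mask_start == num_inputs keeps every input.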
@@ -1568,8 +1573,14 @@ class AttentionCombine(nn.Module):
             mask = torch.logical_or(mask,
                                     single_prob_mask)
 
-            weights = weights.masked_fill(mask, float('-inf'))
-        weights = weights.softmax(dim=1)
+            scores = scores.masked_fill(mask, float('-inf'))
+
+        if self.training and random.random() < 0.1:
+            scores = penalize_abs_values_gt(scores,
+                                            limit=10.0,
+                                            penalty=1.0e-04)
+
+        weights = scores.softmax(dim=1)
 
         # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
         ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
@@ -545,6 +545,24 @@ class ActivationBalancer(torch.nn.Module):
         return _no_op(x)
 
 
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+    """
+    Returns x unmodified, but in backprop will put a penalty for the excess of
+    the absolute values of elements of x over the limit "limit".  E.g. if
+    limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+    Caution: the value of this penalty will be affected by the grad scaling used
+    in automatic mixed precision training.  For the way we use this here that
+    shouldn't really matter, or may even be helpful; we just use this to
+    disallow really implausible values of scores being given to softmax.
+    """
+    aux_loss = penalty * (x.abs() - limit).relu()
+    # note: we don't do sum() here on aux_loss, but it's as if we had done
+    # sum() due to how with_loss() works.
+    x = with_loss(x, aux_loss)
+    # you must use x for something, or this will be ineffective.
+    return x
+
+
 def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
     if x.ndim == 2:
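with_loss() itself is not part of this diff. As a rough, self-contained sketch of the behaviour the docstring describes (an assumption about the mechanism, not the repository's actual implementation): forward returns x unchanged, and backward acts as if aux_loss.sum() had been added to the training loss. With such a helper, penalize_abs_values_gt adds a gradient of magnitude `penalty` to every element whose absolute value exceeds `limit`, and nothing elsewhere. The names _WithLossSketch and penalize_abs_values_gt_sketch below are placeholders:

import torch
from torch import Tensor


class _WithLossSketch(torch.autograd.Function):
    # Hypothetical stand-in for scaling.with_loss(): forward is the identity
    # on x; backward hands aux_loss a gradient of ones, i.e. it behaves as if
    # aux_loss.sum() had been added to the final loss.
    @staticmethod
    def forward(ctx, x: Tensor, aux_loss: Tensor) -> Tensor:
        ctx.aux_shape = aux_loss.shape
        return x.clone()  # a real implementation might avoid the copy

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        ones = torch.ones(ctx.aux_shape, dtype=x_grad.dtype, device=x_grad.device)
        return x_grad, ones


def penalize_abs_values_gt_sketch(x: Tensor, limit: float, penalty: float) -> Tensor:
    aux_loss = penalty * (x.abs() - limit).relu()
    return _WithLossSketch.apply(x, aux_loss)


scores = torch.tensor([0.5, -3.0, 30.0], requires_grad=True)
out = penalize_abs_values_gt_sketch(scores, limit=25.0, penalty=1.0e-04)
out.softmax(dim=0).sum().backward()   # "you must use x for something"
print(scores.grad)
# Only the element at 30.0 exceeds the limit: it picks up an extra gradient of
# about +1e-4 (penalty * d|x|/dx); the in-range elements get ~0 from this term.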