Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Remove LinearWithAuxLoss; simplify schedule of prob in ActivationBalancer.
This commit is contained in:
parent 3213c18a22
commit 56ac7354df
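A rough sketch of what the prob-schedule simplification amounts to, assuming ScheduledFloat interpolates linearly between its (batch_count, value) points and holds the end values outside that range; piecewise_linear below is a hypothetical stand-in for it, not the library implementation:

# Compare the old exponential decay of ActivationBalancer's work probability with the
# new default schedule ScheduledFloat((0.0, 0.4), (8000.0, 0.1)) introduced below.

def old_prob(batch_count: float, min_prob: float = 0.1) -> float:
    # formula removed by this commit (see the forward() hunk below)
    return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))

def piecewise_linear(batch_count: float, x0: float, y0: float, x1: float, y1: float) -> float:
    # assumed ScheduledFloat-like behaviour: linear between the two points, clamped outside
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)

def new_prob(batch_count: float) -> float:
    # new default in ActivationBalancer: prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1))
    return piecewise_linear(batch_count, 0.0, 0.4, 8000.0, 0.1)

for bc in (0, 2000, 4000, 8000, 16000):
    print(bc, round(old_prob(bc), 3), round(new_prob(bc), 3))

Both start with the balancer active on a large fraction of batches and settle near 0.1 around batch 8000; the new form simply makes the schedule an explicit, overridable constructor argument (prob) instead of a hard-coded function of self.batch_count and min_prob.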
@@ -407,120 +407,6 @@ class BasicNorm(torch.nn.Module):
         return x * scales


-class LinearWithAuxLossFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, weight: Tensor,
-                aux_grad_scale: float) -> Tensor:
-        """
-        Returns matmul(x, weight.t()).
-        In the backward pass it will include an auxiliary loss based on predicting x from
-        matmul(y, weight).
-        """
-        if torch.is_autocast_enabled():
-            x = x.to(torch.float16)
-        ctx.save_for_backward(x, weight)
-        ctx.aux_grad_scale = aux_grad_scale
-        return torch.matmul(x, weight.t())
-
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
-        x, weight = ctx.saved_tensors
-
-        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
-        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
-                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
-
-        with torch.cuda.amp.autocast(enabled=False):
-            with torch.enable_grad():
-                x = x.to(weight.dtype)
-                x, weight = x.detach(), weight.detach()
-                weight.requires_grad = True
-                # recompute y as we need the gradient; this is easier to implement than
-                # saving y in the context.
-                y = torch.matmul(x, weight.t())
-                z = torch.matmul(y, weight)
-                # subtract mean
-                dims_to_mean = tuple(range(x.ndim-1))
-                x = x - x.mean(dim=dims_to_mean)
-                z = z - z.mean(dim=dims_to_mean)
-                # compute optimal scale on z
-                with torch.no_grad():
-                    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
-                diff = x - alpha * z
-                # meansq is the loss function.
-                meansq = (diff ** 2).mean()
-                meansq.backward()
-                weight_aux_grad = weight.grad
-
-        with torch.cuda.amp.autocast(enabled=False):
-            weight_grad_norm = weight_grad.to(torch.float32).norm()
-            aux_grad_norm = weight_aux_grad.norm()
-            weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
-            weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
-
-        return x_grad, weight_grad, None
-
-
-class LinearWithAuxLoss(nn.Module):
-    """
-    A linear layer with an auxiliary loss that you can put on a schedule, that
-    encourages it to correspond to the largest-variance directions of the
-    input features.
-
-    Suppose the input is x, and this layer computes:
-          y = M x
-    (the bias is applied separately), then we define:
-          z = exp(alpha) * M^T y
-    where alpha is learnable; and the auxiliary loss will be:
-       aux_loss = normalize_mean(z - x)^2.
-    (normalize_mean refers to subtracting the average value per channel,
-    over the minibatch).
-    In the backward pass we compute the derivative of the auxiliary loss
-    and add it to the weight and bias grads, with a scale chosen such
-    that the extra grad's norm equals `aux_grad_scales` times the norm
-    of the existing grad.
-    """
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 bias: bool = True,
-                 aux_grad_scale: Optional[FloatLike] = None,
-                 prob: FloatLike = 0.25,
-                 initial_scale: float = 1.0,
-                 ):
-        super().__init__()
-        if aux_grad_scale is None:
-            aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
-                                            (2000.0, 0.01), (8000.0, 0.0))
-
-        self.aux_grad_scale = aux_grad_scale
-        self.prob = prob
-
-        self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
-                                   * (in_channels ** -0.5) * initial_scale)
-        if bias:
-            self.bias = nn.Parameter(torch.randn(out_channels) *
-                                     0.01 * initial_scale)
-        else:
-            self.register_parameter('bias', None)
-
-
-    def forward(self,
-                x: Tensor):
-        aux_grad_scale = float(self.aux_grad_scale)
-        if (not self.training or torch.jit.is_scripting() or
-                aux_grad_scale == 0.0 or random.random() > float(self.prob)):
-            return torch.nn.functional.linear(x, self.weight, self.bias)
-        else:
-            ans = LinearWithAuxLossFunction.apply(x, self.weight,
-                                                  aux_grad_scale)
-            if self.bias is not None:
-                ans += self.bias
-            return ans
-
-
 def ScaledLinear(*args,
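
For reference, a self-contained sketch of the auxiliary loss that the removed backward pass adds to the weight gradient. It mirrors the removed code above (reconstruct x from y = matmul(x, weight.t()) via z = matmul(y, weight), pick the least-squares scale alpha, and penalize the mean-squared residual after mean subtraction); it is an illustration, not part of the repository:

import torch

def aux_reconstruction_loss(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Auxiliary loss of the removed LinearWithAuxLossFunction.backward (sketch only).

    x:      (..., in_channels) input features
    weight: (out_channels, in_channels), laid out as in nn.Linear
    """
    y = torch.matmul(x, weight.t())   # the layer's output
    z = torch.matmul(y, weight)       # attempted reconstruction of x from y
    # subtract the mean over all leading (non-channel) dims, as in the removed code
    dims_to_mean = tuple(range(x.ndim - 1))
    x = x - x.mean(dim=dims_to_mean)
    z = z - z.mean(dim=dims_to_mean)
    # closed-form optimal scale on z, treated as a constant for the gradient
    with torch.no_grad():
        alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
    diff = x - alpha * z
    return (diff ** 2).mean()

# usage sketch: the removed backward rescaled this loss's weight gradient so that its norm
# equalled aux_grad_scale times the norm of the ordinary weight gradient, then added the two.
w = torch.randn(256, 192, requires_grad=True)
aux_reconstruction_loss(torch.randn(10, 192), w).backward()   # w.grad now holds the aux gradient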
@@ -655,12 +541,14 @@ class ActivationBalancer(torch.nn.Module):
             scale_gain_factor: FloatLike = 0.04,
             min_abs: FloatLike = 0.2,
             max_abs: FloatLike = 100.0,
             min_prob: FloatLike = 0.1,
+            prob: Optional[FloatLike] = None,
     ):
         super(ActivationBalancer, self).__init__()
-        # CAUTION: this code expects self.batch_count to be overwritten in the main training
-        # loop.
-        self.batch_count = 0
+
+        if prob is None:
+            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1))
+        self.prob = prob

         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -670,7 +558,6 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
-        self.min_prob = min_prob
         self.sign_gain_factor = sign_gain_factor
         self.scale_gain_factor = scale_gain_factor

@@ -682,9 +569,7 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting() or not x.requires_grad:
             return _no_op(x)

-        # the prob of doing some work exponentially decreases from 0.5 till it hits
-        # a floor at min_prob (==0.1, by default)
-        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = float(self.prob)

         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels

@@ -36,7 +36,6 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
-    LinearWithAuxLoss,
     Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
@@ -358,13 +357,8 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                           (20000.0, ratio * x),
                           default=x)

-def _aux_grad_scale() -> float:
-    return 0.2
-def _aux_grad_prob_out() -> ScheduledFloat:
-    return 0.0  # ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
-def _aux_grad_prob_in() -> ScheduledFloat:
-    return 0.0  # ScheduledFloat((0.0, 0.25), (1000.0, 0.0))
-    #return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
+def _balancer_schedule(min_prob: float):
+    return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))


@@ -1351,15 +1345,11 @@ class AttentionSqueeze(nn.Module):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim

-        self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
-                                         bias=False,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
-
-        self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
-                                                    bottleneck_dim,
-                                                    aux_grad_scale=_aux_grad_scale(),
-                                                    prob=_aux_grad_prob_in())
+        self.in_proj = nn.Linear(embed_dim, hidden_dim,
+                                 bias=False)
+
+        self.to_bottleneck_proj = nn.Linear(embed_dim,
+                                            bottleneck_dim)

         # bottleneck_balancer is before the actiation. Mostly, for well-trained
         # instances of this module, the mean absolute values per channel are in
@@ -1370,7 +1360,6 @@ class AttentionSqueeze(nn.Module):
             min_positive=0.2, max_positive=0.8,
             min_abs=0.05,
             max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0),
-            min_prob=0.1,
         )
         self.bottleneck_activation = TanSwish()  # in bottleneck
         self.activation = Identity()  # for diagnostics
@@ -1384,13 +1373,13 @@ class AttentionSqueeze(nn.Module):
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.05,
+            prob=_balancer_schedule(0.05),
         )
         self.activation_balancer = ActivationBalancer(
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.05,
+            prob=_balancer_schedule(0.05),
         )
         self.activation_whiten = Whiten(num_groups=1,
                                         whitening_limit=_whitening_schedule(4.0, ratio=3.0),
@@ -1398,17 +1387,16 @@ class AttentionSqueeze(nn.Module):
                                         grad_scale=0.01)


-        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
+        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)

-        self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(),
-                                          prob=_aux_grad_prob_out(),
-                                          bias=False, initial_scale=0.05)
+        self.out_proj = ScaledLinear(hidden_dim, embed_dim,
+                                     bias=False, initial_scale=0.05)

         self.out_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
             min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
+            prob=0.05,  # out of concern for memory usage
         )


@@ -1459,21 +1447,19 @@ class FeedforwardModule(nn.Module):
                  feedforward_dim: int,
                  dropout: FloatLike):
         super(FeedforwardModule, self).__init__()
-        self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
+        self.in_proj = nn.Linear(embed_dim, feedforward_dim)

         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1,
                                                   min_positive=0.3,
                                                   max_positive=1.0,
                                                   min_abs=0.75,
-                                                  max_abs=5.0,
-                                                  min_prob=0.25)
+                                                  max_abs=5.0)
         self.activation = SwooshL()
         self.dropout = Dropout2(dropout)
-        self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
-                                          initial_scale=0.01,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
+        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
+                                     initial_scale=0.01)

         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
@@ -1544,10 +1530,10 @@ class NonlinAttentionModule(nn.Module):
             channels, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
             min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
+            prob=0.05,  # out of concern for memory usage
         )


     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor,
@@ -1615,9 +1601,8 @@ class ConvolutionModule(nn.Module):
             bottleneck_dim = channels


-        self.in_proj = LinearWithAuxLoss(
+        self.in_proj = nn.Linear(
             channels, 2 * bottleneck_dim,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
         )


@@ -1673,9 +1658,8 @@ class ConvolutionModule(nn.Module):
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)

-        self.out_proj = LinearWithAuxLoss(
+        self.out_proj = ScaledLinear(
             bottleneck_dim, channels,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
             initial_scale=0.05,
         )

@@ -1818,8 +1802,7 @@ class Conv2dSubsampling(nn.Module):
         self.scale_max = 1.0
         self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))

-        self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
-                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
+        self.out = nn.Linear(out_height * layer3_channels, out_channels)

         self.dropout = Dropout2(dropout)

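Taken together, the call-site pattern in the hunks above is uniform: projections that were LinearWithAuxLoss become plain nn.Linear, and output projections that want a small initial scale become ScaledLinear (which, per the import comment above, just scales down initial parameter values). A schematic before/after with illustrative names and dimensions, assuming the same from scaling import ScaledLinear as used in the file above:

import torch.nn as nn
from scaling import ScaledLinear  # as imported in the hunks above

class ExampleBlock(nn.Module):
    """Illustrative only; not a module from the repository."""
    def __init__(self, embed_dim: int = 192, hidden_dim: int = 256):
        super().__init__()
        # before this commit:
        #   self.in_proj  = LinearWithAuxLoss(embed_dim, hidden_dim, bias=False,
        #                                     aux_grad_scale=_aux_grad_scale(),
        #                                     prob=_aux_grad_prob_in())
        #   self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim, bias=False,
        #                                     initial_scale=0.05,
        #                                     aux_grad_scale=_aux_grad_scale(),
        #                                     prob=_aux_grad_prob_out())
        # after this commit:
        self.in_proj = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.out_proj = ScaledLinear(hidden_dim, embed_dim, bias=False, initial_scale=0.05)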