Remove LinearWithAuxLoss; simplify schedule of prob in ActivationBalancer.

Author: Daniel Povey, 2022-12-16 15:07:42 +08:00
Parent: 3213c18a22
Commit: 56ac7354df
2 changed files with 28 additions and 160 deletions


@@ -407,120 +407,6 @@ class BasicNorm(torch.nn.Module):
return x * scales
class LinearWithAuxLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, weight: Tensor,
aux_grad_scale: float) -> Tensor:
"""
Returns matmul(x, weight.t()).
In the backward pass it will include an auxiliary loss based on predicting x from
matmul(y, weight).
"""
if torch.is_autocast_enabled():
x = x.to(torch.float16)
ctx.save_for_backward(x, weight)
ctx.aux_grad_scale = aux_grad_scale
return torch.matmul(x, weight.t())
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, None]:
x, weight = ctx.saved_tensors
x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
with torch.cuda.amp.autocast(enabled=False):
with torch.enable_grad():
x = x.to(weight.dtype)
x, weight = x.detach(), weight.detach()
weight.requires_grad = True
# recompute y as we need the gradient; this is easier to implement than
# saving y in the context.
y = torch.matmul(x, weight.t())
z = torch.matmul(y, weight)
# subtract mean
dims_to_mean = tuple(range(x.ndim-1))
x = x - x.mean(dim=dims_to_mean)
z = z - z.mean(dim=dims_to_mean)
# compute optimal scale on z
with torch.no_grad():
alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
diff = x - alpha * z
# meansq is the loss function.
meansq = (diff ** 2).mean()
meansq.backward()
weight_aux_grad = weight.grad
with torch.cuda.amp.autocast(enabled=False):
weight_grad_norm = weight_grad.to(torch.float32).norm()
aux_grad_norm = weight_aux_grad.norm()
weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
return x_grad, weight_grad, None
class LinearWithAuxLoss(nn.Module):
"""
A linear layer with an auxiliary loss that you can put on a schedule, that
encourages it to correspond to the largest-variance directions of the
input features.
Suppose the input is x, and this layer computes:
y = M x
(the bias is applied separately), then we define:
z = alpha * M^T y
where alpha is a scalar chosen in closed form to minimize the loss; the auxiliary loss is:
aux_loss = normalize_mean(z - x)^2.
(normalize_mean refers to subtracting the average value per channel,
over the minibatch).
In the backward pass we compute the derivative of the auxiliary loss
and add it to the weight grad, with a scale chosen such
that the extra grad's norm equals `aux_grad_scale` times the norm
of the existing grad.
"""
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
aux_grad_scale: Optional[FloatLike] = None,
prob: FloatLike = 0.25,
initial_scale: float = 1.0,
):
super().__init__()
if aux_grad_scale is None:
aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
(2000.0, 0.01), (8000.0, 0.0))
self.aux_grad_scale = aux_grad_scale
self.prob = prob
self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
* (in_channels ** -0.5) * initial_scale)
if bias:
self.bias = nn.Parameter(torch.randn(out_channels) *
0.01 * initial_scale)
else:
self.register_parameter('bias', None)
def forward(self,
x: Tensor):
aux_grad_scale = float(self.aux_grad_scale)
if (not self.training or torch.jit.is_scripting() or
aux_grad_scale == 0.0 or random.random() > float(self.prob)):
return torch.nn.functional.linear(x, self.weight, self.bias)
else:
ans = LinearWithAuxLossFunction.apply(x, self.weight,
aux_grad_scale)
if self.bias is not None:
ans += self.bias
return ans
def ScaledLinear(*args,
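For reference, a minimal standalone sketch (not part of the patch; the helper name is illustrative) of the auxiliary loss that the removed LinearWithAuxLossFunction adds in its backward pass: it reconstructs x from y = x @ W.t() via z = alpha * (y @ W), with alpha chosen in closed form, and penalizes the mean squared reconstruction error after removing per-channel means.

import torch

def aux_reconstruction_loss(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # forward projection and its transpose, as in LinearWithAuxLossFunction.backward()
    y = torch.matmul(x, weight.t())
    z = torch.matmul(y, weight)
    # subtract the per-channel mean over the minibatch
    dims_to_mean = tuple(range(x.ndim - 1))
    x = x - x.mean(dim=dims_to_mean)
    z = z - z.mean(dim=dims_to_mean)
    # closed-form scale on z that minimizes the squared error
    with torch.no_grad():
        alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
    return ((x - alpha * z) ** 2).mean()

The gradient of this loss w.r.t. the weight is rescaled so that its norm equals aux_grad_scale times the norm of the ordinary matmul gradient before being added in, which is why driving aux_grad_scale to 0.0 (as the schedules above eventually did) makes the layer behave like a plain linear layer.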
@@ -655,12 +541,14 @@ class ActivationBalancer(torch.nn.Module):
scale_gain_factor: FloatLike = 0.04,
min_abs: FloatLike = 0.2,
max_abs: FloatLike = 100.0,
min_prob: FloatLike = 0.1,
prob: Optional[FloatLike] = None,
):
super(ActivationBalancer, self).__init__()
# CAUTION: this code expects self.batch_count to be overwritten in the main training
# loop.
self.batch_count = 0
if prob is None:
prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1))
self.prob = prob
# actually self.num_channels is no longer needed except for an assertion.
self.num_channels = num_channels
@@ -670,7 +558,6 @@ class ActivationBalancer(torch.nn.Module):
self.max_factor = max_factor
self.min_abs = min_abs
self.max_abs = max_abs
self.min_prob = min_prob
self.sign_gain_factor = sign_gain_factor
self.scale_gain_factor = scale_gain_factor
@@ -682,9 +569,7 @@ class ActivationBalancer(torch.nn.Module):
if torch.jit.is_scripting() or not x.requires_grad:
return _no_op(x)
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
prob = float(self.prob)
if random.random() < prob:
assert x.shape[self.channel_dim] == self.num_channels

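The net effect on ActivationBalancer is that the hard-coded exponential decay of the work probability is replaced by a ScheduledFloat. A rough sketch of the two defaults, assuming ScheduledFloat interpolates linearly between its (batch_count, value) points and holds the endpoint values outside them:

def old_prob(batch_count: float, min_prob: float = 0.1) -> float:
    # previous behaviour: start at 0.5, halve every 4000 batches, floor at min_prob
    return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))

def new_prob(batch_count: float) -> float:
    # default ScheduledFloat((0.0, 0.4), (8000.0, 0.1)): linear from 0.4 to 0.1
    # over the first 8000 batches, then constant at 0.1
    if batch_count >= 8000.0:
        return 0.1
    return 0.4 + (0.1 - 0.4) * (batch_count / 8000.0)

# e.g. batch 0:    old 0.5,   new 0.4
#      batch 4000: old 0.25,  new 0.25
#      batch 8000: old 0.125, new 0.1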

@@ -36,7 +36,6 @@ from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
LinearWithAuxLoss,
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
@@ -358,13 +357,8 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
(20000.0, ratio * x),
default=x)
def _aux_grad_scale() -> float:
return 0.2
def _aux_grad_prob_out() -> ScheduledFloat:
return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
def _aux_grad_prob_in() -> ScheduledFloat:
return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.0))
#return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
def _balancer_schedule(min_prob: float):
return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
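The new _balancer_schedule helper just parameterizes that default by its final value; for instance, the balancers below pass prob=_balancer_schedule(0.05), which (sketch, under the same linear-interpolation assumption as above) amounts to:

prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.05))  # 0.4 at batch 0, decaying to 0.05 by batch 8000, then constant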
@@ -1351,15 +1345,11 @@ class AttentionSqueeze(nn.Module):
super().__init__()
self.bottleneck_dim = bottleneck_dim
self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
bias=False,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
bottleneck_dim,
aux_grad_scale=_aux_grad_scale(),
prob=_aux_grad_prob_in())
self.in_proj = nn.Linear(embed_dim, hidden_dim,
bias=False)
self.to_bottleneck_proj = nn.Linear(embed_dim,
bottleneck_dim)
# bottleneck_balancer is before the activation. Mostly, for well-trained
# instances of this module, the mean absolute values per channel are in
@@ -1370,7 +1360,6 @@ class AttentionSqueeze(nn.Module):
min_positive=0.2, max_positive=0.8,
min_abs=0.05,
max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0),
min_prob=0.1,
)
self.bottleneck_activation = TanSwish() # in bottleneck
self.activation = Identity() # for diagnostics
@@ -1384,13 +1373,13 @@ class AttentionSqueeze(nn.Module):
hidden_dim, channel_dim=-1,
min_positive=0.2, max_positive=0.8,
min_abs=0.2, max_abs=1.0,
min_prob=0.05,
prob=_balancer_schedule(0.05),
)
self.activation_balancer = ActivationBalancer(
hidden_dim, channel_dim=-1,
min_positive=0.2, max_positive=0.8,
min_abs=0.2, max_abs=1.0,
min_prob=0.05,
prob=_balancer_schedule(0.05),
)
self.activation_whiten = Whiten(num_groups=1,
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
@@ -1398,17 +1387,16 @@ class AttentionSqueeze(nn.Module):
grad_scale=0.01)
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)
self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
aux_grad_scale=_aux_grad_scale(),
prob=_aux_grad_prob_out(),
bias=False, initial_scale=0.05)
self.out_proj = ScaledLinear(hidden_dim, embed_dim,
bias=False, initial_scale=0.05)
self.out_balancer = ActivationBalancer(
embed_dim, channel_dim=-1,
min_positive=0.3, max_positive=0.7,
min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
prob=0.05, # out of concern for memory usage
)
@@ -1459,21 +1447,19 @@ class FeedforwardModule(nn.Module):
feedforward_dim: int,
dropout: FloatLike):
super(FeedforwardModule, self).__init__()
self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
self.in_proj = nn.Linear(embed_dim, feedforward_dim)
self.hidden_balancer = ActivationBalancer(feedforward_dim,
channel_dim=-1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
min_prob=0.25)
max_abs=5.0)
self.activation = SwooshL()
self.dropout = Dropout2(dropout)
self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
initial_scale=0.01,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
initial_scale=0.01)
self.out_whiten = Whiten(num_groups=1,
whitening_limit=_whitening_schedule(7.5),
prob=(0.025, 0.25),
@@ -1544,10 +1530,10 @@ class NonlinAttentionModule(nn.Module):
channels, channel_dim=-1,
min_positive=0.3, max_positive=0.7,
min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
prob=0.05, # out of concern for memory usage
)
def forward(self,
x: Tensor,
attn_weights: Tensor,
@@ -1615,9 +1601,8 @@ class ConvolutionModule(nn.Module):
bottleneck_dim = channels
self.in_proj = LinearWithAuxLoss(
self.in_proj = nn.Linear(
channels, 2 * bottleneck_dim,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
)
@@ -1673,9 +1658,8 @@ class ConvolutionModule(nn.Module):
prob=(0.025, 0.25),
grad_scale=0.01)
self.out_proj = LinearWithAuxLoss(
self.out_proj = ScaledLinear(
bottleneck_dim, channels,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
initial_scale=0.05,
)
@@ -1818,8 +1802,7 @@ class Conv2dSubsampling(nn.Module):
self.scale_max = 1.0
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
self.out = nn.Linear(out_height * layer3_channels, out_channels)
self.dropout = Dropout2(dropout)
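Taken together, the call-site changes follow one pattern; a sketch of the before/after, with dimension names taken from the AttentionSqueeze hunks above:

# before: projections were LinearWithAuxLoss with a scheduled auxiliary loss
#   self.in_proj  = LinearWithAuxLoss(embed_dim, hidden_dim, bias=False,
#                                     aux_grad_scale=_aux_grad_scale(),
#                                     prob=_aux_grad_prob_in())
#   self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim, bias=False,
#                                     aux_grad_scale=_aux_grad_scale(),
#                                     prob=_aux_grad_prob_out(),
#                                     initial_scale=0.05)
# after: input-side projections become plain nn.Linear; output-side projections that
# need a small initial parameter scale become ScaledLinear
#   self.in_proj  = nn.Linear(embed_dim, hidden_dim, bias=False)
#   self.out_proj = ScaledLinear(hidden_dim, embed_dim, bias=False, initial_scale=0.05)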