Use LinearWithAuxLoss in squeeze-attention module

Daniel Povey 2022-11-25 16:02:00 +08:00
parent ba348169bf
commit 6a91f343e9
2 changed files with 195 additions and 88 deletions


@@ -29,6 +29,76 @@ from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding
class ScheduledFloat(torch.nn.Module):
"""
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
it does not have a working forward() function. You are supposed to cast it to float, as
in, float(parent_module.whatever), and use it as something like a dropout prob.
It is a floating point value whose value changes depending on the batch count of the
training loop. It is a piecewise linear function where you specify the (x,y) pairs
in sorted order on x; x corresponds to the batch index. For batch-index values before the
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
`default` is used when self.batch_count is not set, when we are not in
training mode, or in torch.jit scripting mode.
"""
def __init__(self,
*args,
default: float = 0.0):
super().__init__()
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = None
self.name = None
self.default = default
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
for i in range(len(args) - 1):
assert args[i + 1][0] > args[i][0], args  # x values must be strictly increasing
self.schedule = args
def extra_repr(self) -> str:
return 'batch_count={}, schedule={}'.format(self.batch_count,
self.schedule)
def __float__(self):
print_prob = 0.0002
def maybe_print(ans):
if random.random() < print_prob:
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
return float(self.default)
if batch_count <= self.schedule[0][0]:
ans = self.schedule[0][1]
maybe_print(ans)
return float(ans)
elif batch_count >= self.schedule[-1][0]:
ans = self.schedule[-1][1]
maybe_print(ans)
return float(ans)
else:
cur_x, cur_y = self.schedule[0]
for i in range(1, len(self.schedule)):
next_x, next_y = self.schedule[i]
if batch_count >= cur_x and batch_count <= next_x:
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
maybe_print(ans)
return float(ans)
cur_x, cur_y = next_x, next_y
assert False
FloatLike = Union[float, ScheduledFloat]
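To make the schedule concrete, here is a minimal usage sketch (illustrative only, not part of the commit; it assumes the ScheduledFloat defined above, with the training loop responsible for writing batch_count):
# Illustrative sketch: the training loop is expected to discover every
# ScheduledFloat via [top_level module].modules() and write batch_count
# into it before each step; here we do that by hand.
sched = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
sched.name = "feature_dropout"  # hypothetical name, normally set by the loop
for batch_idx in [0, 1000, 2000, 4000, 8000]:
    sched.batch_count = batch_idx
    print(batch_idx, float(sched))
# -> 0.2, 0.15, 0.1, 0.0, 0.0: linear interpolation between the (x, y)
# knots, clamped to the first/last y outside the schedule's range.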
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
@@ -317,6 +387,110 @@ class BasicNorm(torch.nn.Module):
return x * scales
class LinearWithAuxLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor,
aux_grad_scale: float) -> Tensor:
"""
Returns matmul(x, weight.t()).
In the backward pass it will include an auxiliary loss based on predicting x from
matmul(y, weight).
"""
ctx.save_for_backward(x, weight, alpha)
ctx.aux_grad_scale = aux_grad_scale
return torch.matmul(x, weight.t())
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
x, weight, alpha = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False):
with torch.enable_grad():
x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
weight.requires_grad = True
alpha.requires_grad = True
# recompute y as we need the gradient; this is easier to implement than
# saving y in the context.
y = torch.matmul(x, weight.t())
z = alpha * torch.matmul(y, weight)
diff = x - z
dims_to_mean = tuple(range(x.ndim-1))
mean = diff.mean(dim=dims_to_mean)
diff = diff - mean # subtract mean.
# meansq is the loss function.
meansq = (diff ** 2).mean()
meansq.backward()
weight_aux_grad = weight.grad
alpha_grad = alpha.grad
x_grad = torch.matmul(ans_grad, weight)
weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
x.reshape(-1, x.shape[-1]))
with torch.cuda.amp.autocast(enabled=False):
weight_grad_norm = weight_grad.to(torch.float32).norm()
aux_grad_norm = weight_aux_grad.norm()
weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
return x_grad, weight_grad, alpha_grad, None
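A quick way to convince yourself the function above behaves as documented is the following hedged sanity-check sketch (not from the commit): the forward output equals a plain matmul, and backward folds in an auxiliary weight gradient rescaled to aux_grad_scale times the norm of the ordinary one.
# Hypothetical sanity check for LinearWithAuxLossFunction.
import torch

x = torch.randn(8, 16)
weight = torch.randn(4, 16, requires_grad=True)
alpha = torch.tensor(1.0, requires_grad=True)

y = LinearWithAuxLossFunction.apply(x, weight, alpha, 0.1)
assert torch.allclose(y, torch.matmul(x, weight.t()))  # forward is a plain matmul

y.sum().backward()
# weight.grad is now the matmul gradient plus an auxiliary component whose
# norm is 0.1x the matmul gradient's norm; alpha.grad comes purely from the
# auxiliary reconstruction loss.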
class LinearWithAuxLoss(nn.Module):
"""
A linear layer with an auxiliary loss that you can put on a schedule, that
encourages it to correspond to the largest-variance directions of the
input features.
Suppose the input is x, and this layer computes:
y = M x
(the bias is applied separately), then we define:
z = alpha * M^T y
where alpha is learnable; and the auxiliary loss will be:
aux_loss = mean(normalize_mean(z - x)^2),
where normalize_mean refers to subtracting, per channel, the average
value taken over the whole minibatch.
In the backward pass we compute the derivative of the auxiliary loss
and add it to the weight and bias grads, with a scale chosen such
that the extra grad's norm equals `aux_grad_scale` times the norm
of the existing grad.
"""
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
aux_grad_scale: Optional[FloatLike] = None,
prob: FloatLike = 0.25,
initial_scale: float = 1.0,
):
super().__init__()
if aux_grad_scale is None:
aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
(2000.0, 0.01), (8000.0, 0.0))
self.aux_grad_scale = aux_grad_scale
self.prob = prob
self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
* (in_channels ** -0.5) * initial_scale)
if bias:
self.bias = nn.Parameter(torch.randn(out_channels) *
0.01 * initial_scale)
else:
self.register_parameter('bias', None)
self.alpha = nn.Parameter(torch.tensor(1.0))
def forward(self,
x: Tensor):
aux_grad_scale = float(self.aux_grad_scale)
if (not self.training or torch.jit.is_scripting() or
aux_grad_scale == 0.0 or random.random() > float(self.prob)):
ans = torch.matmul(x, self.weight.t())
else:
ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
aux_grad_scale)
# self.bias is None when constructed with bias=False.
return ans if self.bias is None else ans + self.bias
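Intended usage mirrors nn.Linear; a hedged sketch (dimensions, schedule point, and loss are made up for illustration):
# Hypothetical usage sketch. With the default schedule the auxiliary loss
# starts at full strength and decays to zero by batch 8000, and it only
# fires on a random `prob` fraction of training-mode forward passes.
proj = LinearWithAuxLoss(in_channels=256, out_channels=64, prob=0.25)
proj.aux_grad_scale.batch_count = 500  # normally written by the training loop
x = torch.randn(10, 20, 256)
y = proj(x)                            # shape (10, 20, 64)
(y ** 2).mean().backward()             # when the aux path fires, its grads are
                                       # already folded into proj.weight.grad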
def ScaledLinear(*args,
initial_scale: float = 1.0,
@@ -417,14 +591,14 @@ class ActivationBalancer(torch.nn.Module):
self,
num_channels: int,
channel_dim: int,
-min_positive: float = 0.05,
-max_positive: float = 0.95,
-max_factor: float = 0.04,
-sign_gain_factor: float = 0.01,
-scale_gain_factor: float = 0.02,
-min_abs: float = 0.2,
-max_abs: float = 100.0,
-min_prob: float = 0.1,
+min_positive: FloatLike = 0.05,
+max_positive: FloatLike = 0.95,
+max_factor: FloatLike = 0.04,
+sign_gain_factor: FloatLike = 0.01,
+scale_gain_factor: FloatLike = 0.02,
+min_abs: FloatLike = 0.2,
+max_abs: FloatLike = 100.0,
+min_prob: FloatLike = 0.1,
):
super(ActivationBalancer, self).__init__()
# CAUTION: this code expects self.batch_count to be overwritten in the main training
@@ -453,25 +627,26 @@ class ActivationBalancer(torch.nn.Module):
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
-prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
+prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
if random.random() < prob:
assert x.shape[self.channel_dim] == self.num_channels
sign_gain_factor = 0.5
-if self.min_positive != 0.0 or self.max_positive != 1.0:
+if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
sign_factor = _compute_sign_factor(x, self.channel_dim,
-self.min_positive, self.max_positive,
-gain_factor=self.sign_gain_factor / prob,
-max_factor=self.max_factor)
+float(self.min_positive),
+float(self.max_positive),
+gain_factor=float(self.sign_gain_factor) / prob,
+max_factor=float(self.max_factor))
else:
sign_factor = None
scale_factor = _compute_scale_factor(x, self.channel_dim,
-min_abs=self.min_abs,
-max_abs=self.max_abs,
-gain_factor=self.scale_gain_factor / prob,
-max_factor=self.max_factor)
+min_abs=float(self.min_abs),
+max_abs=float(self.max_abs),
+gain_factor=float(self.scale_gain_factor) / prob,
+max_factor=float(self.max_factor))
return ActivationBalancerFunction.apply(
x, scale_factor, sign_factor, self.channel_dim,
)
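The point of widening these arguments from float to FloatLike is that any balancer knob can now follow a schedule; a hedged sketch of what that enables (values are illustrative, not from the commit):
# Hypothetical sketch: a balancer constraint that tightens over training
# rather than staying fixed, now expressible directly in the constructor.
balancer = ActivationBalancer(
    num_channels=384,
    channel_dim=-1,
    min_abs=ScheduledFloat((0.0, 0.1), (8000.0, 0.2)),
)
# Because ScheduledFloat is itself an nn.Module, assigning it here makes it
# show up in balancer.modules(), where the training loop can find it and
# write batch_count before each step.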
@@ -519,74 +694,6 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
-class ScheduledFloat(torch.nn.Module):
-"""
-This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-it does not have a working forward() function. You are supposed to cast it to float, as
-in, float(parent_module.whatever), and use it as something like a dropout prob.
-It is a floating point value whose value changes depending on the batch count of the
-training loop. It is a piecewise linear function where you specifiy the (x,y) pairs
-in sorted order on x; x corresponds to the batch index. For batch-index values before the
-first x or after the last x, we just use the first or last y value.
-Example:
-self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-`default` is used when self.batch_count is not set or in training or mode or in
-torch.jit scripting mode.
-"""
-def __init__(self,
-*args,
-default: float = 0.0):
-super().__init__()
-# self.batch_count and self.name will be written to in the training loop.
-self.batch_count = None
-self.name = None
-self.default = default
-assert len(args) >= 1
-for (x,y) in args:
-assert x >= 0
-for i in range(len(args) - 1):
-assert args[i + 1] > args[i], args
-self.schedule = args
-def extra_repr(self) -> str:
-return 'batch_count={}, schedule={}'.format(self.batch_count,
-self.schedule)
-def __float__(self):
-print_prob = 0.0002
-def maybe_print(ans):
-if random.random() < print_prob:
-logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-batch_count = self.batch_count
-if batch_count is None or not self.training or torch.jit.is_scripting():
-return float(self.default)
-if batch_count <= self.schedule[0][0]:
-ans = self.schedule[0][1]
-maybe_print(ans)
-return float(ans)
-elif batch_count >= self.schedule[-1][0]:
-ans = self.schedule[-1][1]
-maybe_print(ans)
-return float(ans)
-else:
-cur_x, cur_y = self.schedule[0]
-for i in range(1, len(self.schedule)):
-next_x, next_y = self.schedule[i]
-if batch_count >= cur_x and batch_count <= next_x:
-ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-maybe_print(ans)
-return float(ans)
-cur_x, cur_y = next_x, next_y
-assert False
-FloatLike = Union[float, ScheduledFloat]
def _whitening_metric(x: Tensor,
num_groups: int):
"""


@@ -32,6 +32,7 @@ from scaling import (
TanSwish,
ScaledConv1d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
+LinearWithAuxLoss,
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
@@ -1288,9 +1289,8 @@ class AttentionSqueeze(nn.Module):
self.in_proj = nn.Linear(embed_dim, embed_dim,
bias=False)
-self.to_bottleneck_proj = nn.Linear(embed_dim,
-bottleneck_dim,
-bias=False)
+self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
+bottleneck_dim)
# the main reason for this balancer is to keep the bottleneck activations in a "reasonable"
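One detail worth noting in the hunk above: the old projection passed bias=False, while LinearWithAuxLoss is constructed here with its default bias=True, so the bottleneck projection gains a near-zero-initialized bias along with the auxiliary loss. Apart from that, the layer is a drop-in replacement, as this hedged sketch (hypothetical sizes, not from the commit) illustrates:
# Hypothetical equivalence check: with the aux loss disabled and the bias
# zeroed, the new projection computes the same function as the old one.
import torch
from torch import nn

old = nn.Linear(512, 128, bias=False)
new = LinearWithAuxLoss(512, 128, aux_grad_scale=0.0)
with torch.no_grad():
    new.weight.copy_(old.weight)
    new.bias.zero_()
new.eval()
x = torch.randn(3, 512)
assert torch.allclose(old(x), new(x), atol=1e-6)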