Mirror of https://github.com/k2-fsa/icefall.git
Use LinearWithAuxLoss in squeeze-attention module
commit 6a91f343e9 (parent ba348169bf)
@@ -29,6 +29,76 @@ from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding
 
 
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function.  You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop.  It is a piecewise linear function where you specify the (x, y) pairs
+    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, when we are not in training mode,
+    or in torch.jit scripting mode.
+    """
+    def __init__(self,
+                 *args,
+                 default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        assert len(args) >= 1
+        for (x, y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        print_prob = 0.0002
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+        batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
+        if batch_count <= self.schedule[0][0]:
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return float(ans)
+        elif batch_count >= self.schedule[-1][0]:
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return float(ans)
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return float(ans)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(
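To make the schedule concrete, here is a small standalone illustration (not part of the commit) of the piecewise-linear rule that ScheduledFloat.__float__ implements once the training loop has set batch_count and the module is in training mode; the schedule values below are made up:

# Illustration only: the interpolation performed by ScheduledFloat.__float__.
schedule = ((0.0, 1.0), (1000.0, 0.1), (2000.0, 0.01), (8000.0, 0.0))

def schedule_value(batch_count: float) -> float:
    # Before the first x or after the last x, clamp to the first/last y value.
    if batch_count <= schedule[0][0]:
        return schedule[0][1]
    if batch_count >= schedule[-1][0]:
        return schedule[-1][1]
    # Otherwise interpolate linearly between the surrounding (x, y) pairs.
    for (cur_x, cur_y), (next_x, next_y) in zip(schedule, schedule[1:]):
        if cur_x <= batch_count <= next_x:
            return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)

print(schedule_value(500.0))    # 0.55, halfway between 1.0 and 0.1
print(schedule_value(1500.0))   # 0.055
print(schedule_value(9999.0))   # 0.0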
@@ -317,6 +387,110 @@ class BasicNorm(torch.nn.Module):
         return x * scales
 
 
+class LinearWithAuxLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor,
+                aux_grad_scale: float) -> Tensor:
+        """
+        Returns matmul(x, weight.t()).
+        In the backward pass it will include an auxiliary loss based on predicting x from
+        matmul(y, weight).
+        """
+        ctx.save_for_backward(x, weight, alpha)
+        ctx.aux_grad_scale = aux_grad_scale
+        return torch.matmul(x, weight.t())
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
+        x, weight, alpha = ctx.saved_tensors
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.enable_grad():
+                x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
+                weight.requires_grad = True
+                alpha.requires_grad = True
+                # recompute y as we need the gradient; this is easier to implement than
+                # saving y in the context.
+                y = torch.matmul(x, weight.t())
+                z = alpha * torch.matmul(y, weight)
+                diff = x - z
+                dims_to_mean = tuple(range(x.ndim - 1))
+                mean = diff.mean(dim=dims_to_mean)
+                diff = diff - mean  # subtract mean.
+                # meansq is the loss function.
+                meansq = (diff ** 2).mean()
+                meansq.backward()
+                weight_aux_grad = weight.grad
+                alpha_grad = alpha.grad
+
+        x_grad = torch.matmul(ans_grad, weight)
+        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
+                                   x.reshape(-1, x.shape[-1]))
+
+        with torch.cuda.amp.autocast(enabled=False):
+            weight_grad_norm = weight_grad.to(torch.float32).norm()
+            aux_grad_norm = weight_aux_grad.norm()
+            weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
+            weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
+
+        return x_grad, weight_grad, alpha_grad, None
+
+
+class LinearWithAuxLoss(nn.Module):
+    """
+    A linear layer with an auxiliary loss that you can put on a schedule, which
+    encourages it to correspond to the largest-variance directions of the
+    input features.
+
+    Suppose the input is x, and this layer computes:
+          y = M x
+    (the bias is applied separately); then we define:
+          z = alpha * M^T y
+    where alpha is learnable; and the auxiliary loss will be:
+         aux_loss = normalize_mean(z - x)^2.
+    (normalize_mean refers to subtracting the average value per channel,
+    over the minibatch).
+    In the backward pass we compute the derivative of the auxiliary loss
+    and add it to the weight grad, with a scale chosen such
+    that the extra grad's norm equals `aux_grad_scale` times the norm
+    of the existing grad.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 bias: bool = True,
+                 aux_grad_scale: Optional[FloatLike] = None,
+                 prob: FloatLike = 0.25,
+                 initial_scale: float = 1.0,
+                 ):
+        super().__init__()
+        if aux_grad_scale is None:
+            aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
+                                            (2000.0, 0.01), (8000.0, 0.0))
+
+        self.aux_grad_scale = aux_grad_scale
+        self.prob = prob
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
+                                   * (in_channels ** -0.5) * initial_scale)
+        if bias:
+            self.bias = nn.Parameter(torch.randn(out_channels) *
+                                     0.01 * initial_scale)
+        else:
+            self.register_parameter('bias', None)
+        self.alpha = nn.Parameter(torch.tensor(1.0))
+
+    def forward(self,
+                x: Tensor):
+        aux_grad_scale = float(self.aux_grad_scale)
+        if (not self.training or torch.jit.is_scripting() or
+                aux_grad_scale == 0.0 or random.random() > float(self.prob)):
+            return torch.matmul(x, self.weight.t()) + self.bias
+        else:
+            return LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+                                                   aux_grad_scale) + self.bias
+
 
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
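For intuition, the auxiliary loss that LinearWithAuxLossFunction differentiates in its backward pass can be written out directly. The sketch below is an illustration only, not part of the diff; the shapes and values are made up:

# Illustration only: the auxiliary loss behind LinearWithAuxLossFunction.
import torch

batch, in_channels, out_channels = 16, 64, 32
x = torch.randn(batch, in_channels)
weight = torch.randn(out_channels, in_channels) * (in_channels ** -0.5)
alpha = torch.tensor(1.0)

y = torch.matmul(x, weight.t())      # the layer's actual output
z = alpha * torch.matmul(y, weight)  # map the output back to the input space
diff = x - z
diff = diff - diff.mean(dim=0)       # subtract the per-channel mean over the batch
aux_loss = (diff ** 2).mean()        # small when the rows of `weight` span the
                                     # largest-variance directions of x

Note that this loss is never added to the training objective: backward() only computes its gradients, rescales the weight component so its norm equals aux_grad_scale times the norm of the ordinary weight gradient before adding it in, and passes the alpha gradient through directly.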
@@ -417,14 +591,14 @@ class ActivationBalancer(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int,
-        min_positive: float = 0.05,
-        max_positive: float = 0.95,
-        max_factor: float = 0.04,
-        sign_gain_factor: float = 0.01,
-        scale_gain_factor: float = 0.02,
-        min_abs: float = 0.2,
-        max_abs: float = 100.0,
-        min_prob: float = 0.1,
+        min_positive: FloatLike = 0.05,
+        max_positive: FloatLike = 0.95,
+        max_factor: FloatLike = 0.04,
+        sign_gain_factor: FloatLike = 0.01,
+        scale_gain_factor: FloatLike = 0.02,
+        min_abs: FloatLike = 0.2,
+        max_abs: FloatLike = 100.0,
+        min_prob: FloatLike = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
         # CAUTION: this code expects self.batch_count to be overwritten in the main training
@@ -453,25 +627,26 @@ class ActivationBalancer(torch.nn.Module):
 
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels
             sign_gain_factor = 0.5
-            if self.min_positive != 0.0 or self.max_positive != 1.0:
+            if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
                 sign_factor = _compute_sign_factor(x, self.channel_dim,
-                                                   self.min_positive, self.max_positive,
-                                                   gain_factor=self.sign_gain_factor / prob,
-                                                   max_factor=self.max_factor)
+                                                   float(self.min_positive),
+                                                   float(self.max_positive),
+                                                   gain_factor=float(self.sign_gain_factor) / prob,
+                                                   max_factor=float(self.max_factor))
             else:
                 sign_factor = None
 
             scale_factor = _compute_scale_factor(x, self.channel_dim,
-                                                 min_abs=self.min_abs,
-                                                 max_abs=self.max_abs,
-                                                 gain_factor=self.scale_gain_factor / prob,
-                                                 max_factor=self.max_factor)
+                                                 min_abs=float(self.min_abs),
+                                                 max_abs=float(self.max_abs),
+                                                 gain_factor=float(self.scale_gain_factor) / prob,
+                                                 max_factor=float(self.max_factor))
             return ActivationBalancerFunction.apply(
                 x, scale_factor, sign_factor, self.channel_dim,
             )
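As a usage note (illustrative, not taken from any recipe), the switch to FloatLike means any of these arguments can now be a ScheduledFloat rather than a plain float; forward() re-reads it with float() on every call, so the effective value follows the schedule as the training loop advances batch_count:

# Sketch under assumed imports from scaling.py; the schedule values are made up.
from scaling import ActivationBalancer, ScheduledFloat  # assumed import

balancer = ActivationBalancer(
    num_channels=256,
    channel_dim=-1,
    # Hypothetical schedule: loosen the max-absolute-value constraint over training.
    max_abs=ScheduledFloat((0.0, 10.0), (20000.0, 100.0)),
)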
@@ -519,74 +694,6 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
 
 
-class ScheduledFloat(torch.nn.Module):
-    """
-    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-    it does not have a working forward() function.  You are supposed to cast it to float, as
-    in, float(parent_module.whatever), and use it as something like a dropout prob.
-
-    It is a floating point value whose value changes depending on the batch count of the
-    training loop.  It is a piecewise linear function where you specify the (x, y) pairs
-    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
-    first x or after the last x, we just use the first or last y value.
-
-    Example:
-        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
-    `default` is used when self.batch_count is not set, when we are not in training mode,
-    or in torch.jit scripting mode.
-    """
-    def __init__(self,
-                 *args,
-                 default: float = 0.0):
-        super().__init__()
-        # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = None
-        self.name = None
-        self.default = default
-        assert len(args) >= 1
-        for (x, y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
-
-    def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
-
-    def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-        batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
-            return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
-        else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
-
-
-FloatLike = Union[float, ScheduledFloat]
-
-
 def _whitening_metric(x: Tensor,
                       num_groups: int):
     """
@@ -32,6 +32,7 @@ from scaling import (
     TanSwish,
     ScaledConv1d,
     ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
+    LinearWithAuxLoss,
     Whiten,
     Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
@@ -1288,9 +1289,8 @@ class AttentionSqueeze(nn.Module):
         self.in_proj = nn.Linear(embed_dim, embed_dim,
                                  bias=False)
 
-        self.to_bottleneck_proj = nn.Linear(embed_dim,
-                                            bottleneck_dim,
-                                            bias=False)
+        self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
+                                                    bottleneck_dim)
 
 
         # the main reason for this balancer is to keep the bottleneck activations in a "reasonable"
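Finally, an illustrative sketch (not part of the commit) of the new bottleneck projection used as a drop-in module; the dimensions below are made up:

import torch
from scaling import LinearWithAuxLoss  # as imported in the hunk above

proj = LinearWithAuxLoss(in_channels=512, out_channels=64)
x = torch.randn(8, 100, 512)   # (batch, time, embed_dim)
y = proj(x)                    # (8, 100, 64)
# The auxiliary gradient is only injected during training, on a random ~25% of
# calls (prob=0.25), and only while the aux_grad_scale schedule is nonzero;
# otherwise this behaves like an ordinary randomly initialized linear layer.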