Use LinearWithAuxLoss in squeeze-attention module

commit 6a91f343e9
parent ba348169bf
In scaling.py:

@@ -29,6 +29,76 @@ from torch import Tensor
 
 from torch.nn import Embedding as ScaledEmbedding
 
+
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function.  You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop.  It is a piecewise linear function where you specify the (x,y) pairs
+    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, or not in training mode, or in
+    torch.jit scripting mode.
+    """
+    def __init__(self,
+                 *args,
+                 default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        assert len(args) >= 1
+        for (x, y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        print_prob = 0.0002
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+        batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
+        if batch_count <= self.schedule[0][0]:
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return float(ans)
+        elif batch_count >= self.schedule[-1][0]:
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return float(ans)
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return float(ans)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(
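To make the intended usage concrete, here is a minimal sketch (not part of the commit) of how a ScheduledFloat is meant to be used. It assumes the scaling module above is importable as `scaling`, and that the training loop writes `batch_count` (and optionally `name`) into every ScheduledFloat submodule, as the class comment states; `ToyModule` and the schedule values are illustrative only.

import torch
from scaling import ScheduledFloat   # assumes scaling.py from this commit is on the path


class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # dropout prob decays linearly from 0.2 at batch 0 to 0.0 at batch 4000.
        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast to float at the point of use; ScheduledFloat has no working forward().
        return torch.nn.functional.dropout(x, p=float(self.dropout),
                                           training=self.training)


m = ToyModule()
m.train()
for name, module in m.named_modules():
    if isinstance(module, ScheduledFloat):
        module.batch_count = 2000.0   # normally written by the training loop
        module.name = name
print(float(m.dropout))   # 0.1: halfway along the (0.0, 0.2) -> (4000.0, 0.0) schedule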
@@ -317,6 +387,110 @@ class BasicNorm(torch.nn.Module):
         return x * scales
 
 
+class LinearWithAuxLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor,
+                aux_grad_scale: float) -> Tensor:
+        """
+        Returns matmul(x, weight.t()).
+        In the backward pass it will include an auxiliary loss based on predicting x from
+        matmul(y, weight).
+        """
+        ctx.save_for_backward(x, weight, alpha)
+        ctx.aux_grad_scale = aux_grad_scale
+        return torch.matmul(x, weight.t())
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
+        x, weight, alpha = ctx.saved_tensors
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.enable_grad():
+                x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
+                weight.requires_grad = True
+                alpha.requires_grad = True
+                # recompute y as we need the gradient; this is easier to implement than
+                # saving y in the context.
+                y = torch.matmul(x, weight.t())
+                z = alpha * torch.matmul(y, weight)
+                diff = x - z
+                dims_to_mean = tuple(range(x.ndim - 1))
+                mean = diff.mean(dim=dims_to_mean)
+                diff = diff - mean  # subtract mean.
+                # meansq is the loss function.
+                meansq = (diff ** 2).mean()
+                meansq.backward()
+                weight_aux_grad = weight.grad
+                alpha_grad = alpha.grad
+
+        x_grad = torch.matmul(ans_grad, weight)
+        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
+                                   x.reshape(-1, x.shape[-1]))
+
+        with torch.cuda.amp.autocast(enabled=False):
+            weight_grad_norm = weight_grad.to(torch.float32).norm()
+            aux_grad_norm = weight_aux_grad.norm()
+            weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
+            weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
+
+        return x_grad, weight_grad, alpha_grad, None
+
+
+class LinearWithAuxLoss(nn.Module):
+    """
+    A linear layer with an auxiliary loss that you can put on a schedule, that
+    encourages it to correspond to the largest-variance directions of the
+    input features.
+
+    Suppose the input is x, and this layer computes:
+           y = M x
+    (the bias is applied separately); then we define:
+           z = alpha * M^T y
+    where alpha is learnable, and the auxiliary loss will be:
+        aux_loss = normalize_mean(z - x)^2
+    (normalize_mean refers to subtracting the average value per channel,
+    over the minibatch).
+    In the backward pass we compute the derivative of the auxiliary loss
+    and add it to the weight and bias grads, with a scale chosen such
+    that the extra grad's norm equals `aux_grad_scale` times the norm
+    of the existing grad.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 bias: bool = True,
+                 aux_grad_scale: Optional[FloatLike] = None,
+                 prob: FloatLike = 0.25,
+                 initial_scale: float = 1.0,
+                 ):
+        super().__init__()
+        if aux_grad_scale is None:
+            aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
+                                            (2000.0, 0.01), (8000.0, 0.0))
+
+        self.aux_grad_scale = aux_grad_scale
+        self.prob = prob
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
+                                   * (in_channels ** -0.5) * initial_scale)
+        if bias:
+            self.bias = nn.Parameter(torch.randn(out_channels) *
+                                     0.01 * initial_scale)
+        else:
+            self.register_parameter('bias', None)
+        self.alpha = nn.Parameter(torch.tensor(1.0))
+
+    def forward(self,
+                x: Tensor):
+        aux_grad_scale = float(self.aux_grad_scale)
+        if (not self.training or torch.jit.is_scripting() or
+                aux_grad_scale == 0.0 or random.random() > float(self.prob)):
+            return torch.matmul(x, self.weight.t()) + self.bias
+        else:
+            return LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+                                                   aux_grad_scale) + self.bias
+
+
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
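As a worked illustration of the auxiliary loss defined above (a standalone sketch, not the layer itself): given an input x and weight M, it forms y = x M^T, reconstructs z = alpha * y M, subtracts the per-channel mean of x - z over the leading dims, and averages the squares. In the real backward pass the resulting weight gradient is rescaled so its norm equals aux_grad_scale times the norm of the main weight gradient before being added; the function and tensor names below are illustrative only.

import torch

def aux_loss(x: torch.Tensor, weight: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    y = torch.matmul(x, weight.t())        # the layer's output, shape (..., out_channels)
    z = alpha * torch.matmul(y, weight)    # attempted reconstruction of x
    diff = x - z
    diff = diff - diff.mean(dim=tuple(range(x.ndim - 1)))   # "normalize_mean" over batch dims
    return (diff ** 2).mean()

x = torch.randn(8, 16)                                              # (batch, in_channels)
weight = (torch.randn(4, 16) * 16 ** -0.5).requires_grad_(True)     # (out_channels, in_channels)
alpha = torch.tensor(1.0, requires_grad=True)
loss = aux_loss(x, weight, alpha)
loss.backward()
# weight.grad is the "aux grad" that LinearWithAuxLossFunction.backward rescales and
# adds to the ordinary matmul gradient; alpha.grad trains the reconstruction scale.
print(loss.item(), weight.grad.norm().item(), alpha.grad.item())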
@@ -417,14 +591,14 @@ class ActivationBalancer(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int,
-        min_positive: float = 0.05,
-        max_positive: float = 0.95,
-        max_factor: float = 0.04,
-        sign_gain_factor: float = 0.01,
-        scale_gain_factor: float = 0.02,
-        min_abs: float = 0.2,
-        max_abs: float = 100.0,
-        min_prob: float = 0.1,
+        min_positive: FloatLike = 0.05,
+        max_positive: FloatLike = 0.95,
+        max_factor: FloatLike = 0.04,
+        sign_gain_factor: FloatLike = 0.01,
+        scale_gain_factor: FloatLike = 0.02,
+        min_abs: FloatLike = 0.2,
+        max_abs: FloatLike = 100.0,
+        min_prob: FloatLike = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
         # CAUTION: this code expects self.batch_count to be overwritten in the main training
@@ -453,25 +627,26 @@ class ActivationBalancer(torch.nn.Module):
 
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels
             sign_gain_factor = 0.5
-            if self.min_positive != 0.0 or self.max_positive != 1.0:
+            if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
                 sign_factor = _compute_sign_factor(x, self.channel_dim,
-                                                   self.min_positive, self.max_positive,
-                                                   gain_factor=self.sign_gain_factor / prob,
-                                                   max_factor=self.max_factor)
+                                                   float(self.min_positive),
+                                                   float(self.max_positive),
+                                                   gain_factor=float(self.sign_gain_factor) / prob,
+                                                   max_factor=float(self.max_factor))
             else:
                 sign_factor = None
 
 
             scale_factor = _compute_scale_factor(x, self.channel_dim,
-                                                 min_abs=self.min_abs,
-                                                 max_abs=self.max_abs,
-                                                 gain_factor=self.scale_gain_factor / prob,
-                                                 max_factor=self.max_factor)
+                                                 min_abs=float(self.min_abs),
+                                                 max_abs=float(self.max_abs),
+                                                 gain_factor=float(self.scale_gain_factor) / prob,
+                                                 max_factor=float(self.max_factor))
             return ActivationBalancerFunction.apply(
                 x, scale_factor, sign_factor, self.channel_dim,
             )
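The point of widening these annotations to FloatLike and casting with float() at the point of use is that any of these knobs can now be a ScheduledFloat rather than a constant. A small usage sketch (not part of the commit; assumes the scaling module above is importable, with illustrative channel counts and schedule values):

from scaling import ActivationBalancer, ScheduledFloat   # assumes scaling.py is on the path

balancer = ActivationBalancer(
    num_channels=256,
    channel_dim=-1,
    # ramp the minimum-absolute-value target up over the first 4000 batches...
    min_abs=ScheduledFloat((0.0, 0.1), (4000.0, 0.2)),
    # ...while plain floats keep working exactly as before.
    max_abs=100.0,
)
# As with ScheduledFloat itself, the training loop is expected to write `batch_count`
# into the balancer and its ScheduledFloat submodules; forward() then resolves each
# FloatLike with float(...) on every batch where it decides to do work.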
@@ -519,74 +694,6 @@ def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
 
 
 
-class ScheduledFloat(torch.nn.Module):
-    """
-    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-    it does not have a working forward() function.  You are supposed to cast it to float, as
-    in, float(parent_module.whatever), and use it as something like a dropout prob.
-
-    It is a floating point value whose value changes depending on the batch count of the
-    training loop.  It is a piecewise linear function where you specify the (x,y) pairs
-    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
-    first x or after the last x, we just use the first or last y value.
-
-    Example:
-       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
-    `default` is used when self.batch_count is not set, or not in training mode, or in
-    torch.jit scripting mode.
-    """
-    def __init__(self,
-                 *args,
-                 default: float = 0.0):
-        super().__init__()
-        # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = None
-        self.name = None
-        self.default = default
-        assert len(args) >= 1
-        for (x, y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
-
-    def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
-
-    def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-        batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
-            return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
-        else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
-
-
-FloatLike = Union[float, ScheduledFloat]
-
-
 def _whitening_metric(x: Tensor,
                       num_groups: int):
     """
In the model file that defines AttentionSqueeze (it imports these helpers from scaling.py):

@@ -32,6 +32,7 @@ from scaling import (
     TanSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    LinearWithAuxLoss,
     Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
@@ -1288,9 +1289,8 @@ class AttentionSqueeze(nn.Module):
         self.in_proj = nn.Linear(embed_dim, embed_dim,
                                  bias=False)
 
-        self.to_bottleneck_proj = nn.Linear(embed_dim,
-                                            bottleneck_dim,
-                                            bias=False)
+        self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
+                                                    bottleneck_dim)
 
 
         # the main reason for this balancer is to keep the bottleneck activations in a "reasonable"
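For reference, a hedged sketch of what the swapped-in projection amounts to with default arguments (not part of the commit; assumes the scaling module is importable, and the dimensions are illustrative). Note two behavioural differences from the old nn.Linear call: no bias=False is passed, so the bottleneck projection now has a bias, and the auxiliary gradient is only applied on a random prob=0.25 fraction of training calls while the default aux_grad_scale schedule is still nonzero (it decays to 0.0 by batch 8000).

import torch
from scaling import LinearWithAuxLoss   # assumes scaling.py from this commit is on the path

embed_dim, bottleneck_dim = 384, 16      # illustrative sizes, not taken from the diff
proj = LinearWithAuxLoss(embed_dim, bottleneck_dim)
proj.train()
for m in proj.modules():                 # aux_grad_scale is a ScheduledFloat submodule
    m.batch_count = 0.0                  # normally written by the training loop
x = torch.randn(2, 100, embed_dim, requires_grad=True)
y = proj(x)                              # shape (2, 100, bottleneck_dim)
y.sum().backward()                       # weight grad may include the rescaled aux grad
print(y.shape, proj.weight.grad.norm().item())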