Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 03a77f8ae5
Merge branch 'scaled_adam_exp7c' into scaled_adam_exp11c
@@ -173,7 +173,8 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1, max_abs=10.0),
+            ActivationBalancer(dim_feedforward,
+                               channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
@@ -182,7 +183,8 @@ class ConformerEncoderLayer(nn.Module):

         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1, max_abs=10.0),
+            ActivationBalancer(dim_feedforward,
+                               channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
@@ -196,7 +198,10 @@ class ConformerEncoderLayer(nn.Module):

         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
-            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+            d_model, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            max_abs=6.0,
+            max_var_per_eig=0.2,
         )

         self.dropout = nn.Dropout(dropout)
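Across these ConformerEncoderLayer hunks the pattern is the same: ActivationBalancer now takes the channel count as a required first positional argument (it needs to know the channel count at construction time for the new max_var_per_eig machinery). A minimal sketch of the new call pattern, with illustrative dimensions, assuming the updated ActivationBalancer, DoubleSwish and ScaledLinear from this repo's scaling.py are importable:

```python
import torch
from scaling import ActivationBalancer, DoubleSwish, ScaledLinear  # assumed importable

d_model, dim_feedforward = 512, 2048

feed_forward = torch.nn.Sequential(
    torch.nn.Linear(d_model, dim_feedforward),
    # The channel count (here dim_feedforward) is now the first positional argument.
    ActivationBalancer(dim_feedforward, channel_dim=-1, max_abs=10.0),
    DoubleSwish(),
    torch.nn.Dropout(0.1),
    ScaledLinear(dim_feedforward, d_model),
)

x = torch.randn(10, 4, d_model)   # (seq_len, batch, d_model)
print(feed_forward(x).shape)      # torch.Size([10, 4, 512])
```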
@@ -247,8 +252,6 @@ class ConformerEncoderLayer(nn.Module):

         # multi-headed self-attention module
         src_att = self.self_attn(
             src,
-            src,
-            src,
             pos_emb=pos_emb,
             attn_mask=src_mask,
@@ -464,8 +467,12 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"

         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
-        self.proj_balancer = ActivationBalancer(channel_dim=-1, max_abs=10.0)
+        self.in_balancer = ActivationBalancer(3 * embed_dim,
+                                              channel_dim=-1, max_abs=5.0,
+                                              max_var_per_eig=0.2)
+        self.proj_balancer = ActivationBalancer(embed_dim,
+                                                channel_dim=-1, max_abs=10.0,
+                                                min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
@@ -484,9 +491,7 @@ class RelPositionMultiheadAttention(nn.Module):

     def forward(
         self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
+        x: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
@@ -494,7 +499,7 @@ class RelPositionMultiheadAttention(nn.Module):
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
-            query, key, value: map a query and a set of key-value pairs to an output.
+            x: input to be projected to query, key, value
             pos_emb: Positional embedding tensor
             key_padding_mask: if provided, specified padding elements in the key will
             be ignored by the attention. When given a binary mask and a value is True,
@@ -507,11 +512,7 @@ class RelPositionMultiheadAttention(nn.Module):

         Shape:
             - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-            the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-            the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
             the embedding dimension.
             - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
             the embedding dimension.
@@ -534,9 +535,7 @@ class RelPositionMultiheadAttention(nn.Module):
             L is the target sequence length, S is the source sequence length.
         """
         return self.multi_head_attention_forward(
-            query,
-            key,
-            value,
+            self.in_balancer(self.in_proj(x)),
             pos_emb,
             self.embed_dim,
             self.num_heads,
@@ -578,11 +577,9 @@ class RelPositionMultiheadAttention(nn.Module):

     def multi_head_attention_forward(
         self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
+        x: Tensor,
         pos_emb: Tensor,
-        embed_dim_to_check: int,
+        embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
         in_proj_bias: Tensor,
@@ -598,7 +595,7 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             query, key, value: map a query and a set of key-value pairs to an output.
             pos_emb: Positional embedding tensor
-            embed_dim_to_check: total dimension of the model.
+            embed_dim: total dimension of the model.
             num_heads: parallel attention heads.
             in_proj_weight, in_proj_bias: input projection weight and bias.
             dropout_p: probability of an element to be zeroed.
@@ -640,9 +637,7 @@ class RelPositionMultiheadAttention(nn.Module):
             L is the target sequence length, S is the source sequence length.
         """

-        tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+        tgt_len, bsz, _ = x.size()

         head_dim = embed_dim // num_heads
         assert (
@@ -651,62 +646,10 @@ class RelPositionMultiheadAttention(nn.Module):

         scaling = float(head_dim) ** -0.5

-        def linear(x, w, b):
-            return self.in_balancer(nn.functional.linear(x, w, b))
-
-        if torch.equal(query, key) and torch.equal(key, value):
-            # self-attention
-            q, k, v = linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
-
-        elif torch.equal(key, value):
-            # encoder-decoder attention
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = linear(query, _w, _b)
-
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim
-            _end = None
-            _w = in_proj_weight[_start:, :]
-            if _b is not None:
-                _b = _b[_start:]
-            k, v = linear(key, _w, _b).chunk(2, dim=-1)
-
-        else:
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = linear(query, _w, _b)
-
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim
-            _end = embed_dim * 2
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            k = linear(key, _w, _b)
-
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim * 2
-            _end = None
-            _w = in_proj_weight[_start:, :]
-            if _b is not None:
-                _b = _b[_start:]
-            v = linear(value, _w, _b)
+        # self-attention
+        q, k, v = x.chunk(3, dim=-1)

         if attn_mask is not None:
             assert (
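Since forward() now receives a single tensor x and the projection plus balancing happen in the caller (self.in_balancer(self.in_proj(x))), the old three-way torch.equal branching is no longer needed: the packed projection is simply split into q, k, v. A small self-contained sketch (illustrative names, not the repository code) showing why chunk(3) on the packed output matches the old sliced projections in the self-attention case:

```python
import torch

embed_dim, seq_len, batch = 8, 5, 2
x = torch.randn(seq_len, batch, embed_dim)

in_proj = torch.nn.Linear(embed_dim, 3 * embed_dim, bias=True)

# New path: project once to 3*embed_dim, then split into q, k, v.
q, k, v = in_proj(x).chunk(3, dim=-1)

# Old path (self-attention case): slice the same packed weight/bias three times.
w, b = in_proj.weight, in_proj.bias
q_old = torch.nn.functional.linear(x, w[:embed_dim], b[:embed_dim])
k_old = torch.nn.functional.linear(x, w[embed_dim:2 * embed_dim], b[embed_dim:2 * embed_dim])
v_old = torch.nn.functional.linear(x, w[2 * embed_dim:], b[2 * embed_dim:])

assert torch.allclose(q, q_old) and torch.allclose(k, k_old) and torch.allclose(v, v_old)
```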
@@ -726,15 +669,15 @@ class RelPositionMultiheadAttention(nn.Module):

         if attn_mask.dim() == 2:
             attn_mask = attn_mask.unsqueeze(0)
-            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+            if list(attn_mask.size()) != [1, tgt_len, tgt_len]:
                 raise RuntimeError(
                     "The size of the 2D attn_mask is not correct."
                 )
         elif attn_mask.dim() == 3:
             if list(attn_mask.size()) != [
                 bsz * num_heads,
-                query.size(0),
-                key.size(0),
+                tgt_len,
+                tgt_len,
             ]:
                 raise RuntimeError(
                     "The size of the 3D attn_mask is not correct."
@@ -900,6 +843,7 @@ class ConvolutionModule(nn.Module):
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
+            2 * channels,
             channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
         )

@@ -914,7 +858,7 @@ class ConvolutionModule(nn.Module):
         )

         self.deriv_balancer2 = ActivationBalancer(
-            channel_dim=1, min_positive=0.05, max_positive=1.0
+            channels, channel_dim=1, min_positive=0.05, max_positive=1.0
         )

         self.activation = DoubleSwish()
@@ -1000,7 +944,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 padding=1,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer1_channels,
+                               channel_dim=1),
             DoubleSwish(),
             nn.Conv2d(
                 in_channels=layer1_channels,
@@ -1008,7 +953,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 stride=2,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer2_channels,
+                               channel_dim=1),
             DoubleSwish(),
             nn.Conv2d(
                 in_channels=layer2_channels,
@@ -1016,7 +962,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 stride=2,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer3_channels,
+                               channel_dim=1),
             DoubleSwish(),
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
@@ -1027,6 +974,7 @@ class Conv2dSubsampling(nn.Module):
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(
+            out_channels,
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )

@@ -254,7 +254,7 @@ class ScaledAdam(Optimizer):
             if ans < 1.0:
                 state["num_clipped"] += 1
             if ans < 0.1:
-                logging.warn("Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
+                logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
             return ans

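The only change in this ScaledAdam hunk is the added f prefix: without it Python prints the braces literally instead of interpolating the values.

```python
ans = 0.05
model_norm_threshold = 2.0
print("Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
# -> Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}
print(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
# -> Scaling gradients by 0.05, model_norm_threshold=2.0
```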
@@ -114,6 +114,108 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None


+def find_direction_coeffs(x: Tensor,
+                          prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Figure out (an approximation to) the proportion of the variance of a set of
+    feature vectors that can be attributed to the top eigen-direction.
+    Args:
+     x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+     prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
+         of the top eigen-direction, or a random direction if this is the first
+         iteration.  Does not have to be normalized, but should be nonzero.
+
+    Returns: (cur_direction, coeffs), where:
+      cur_direction: a Tensor of shape (num_channels,) that is the current
+          estimate of the top eigen-direction.
+      coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
+          approximately minimizes, (x - coeffs * cur_direction).norm()
+    """
+    (num_frames, num_channels) = x.shape
+    assert num_channels > 1 and num_frames > 1
+
+    assert prev_direction.shape == (num_channels,)
+
+    # `coeffs` are the coefficients of `prev_direction` in x;
+    # actually represent the coeffs up to a constant positive factor.
+    coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+
+    cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
+
+    return cur_direction, coeffs
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        direction: Tensor,
+        channel_dim: int,
+        prob: float,
+        subtract_mean: bool,
+        max_variance_proportion: float,
+        grad_scale: float) -> Tuple[Tensor, Tensor]:
+        if random.random() > prob:
+            return x, direction
+        eps = 1.0e-20
+        num_channels = x.shape[channel_dim]
+        assert max_variance_proportion > 1.0 / num_channels
+        orig_x = x
+        x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
+        if subtract_mean:
+            x = x - x.mean(dim=0)
+        new_direction, coeffs = find_direction_coeffs(x, direction)
+        x_var = (x ** 2).mean()
+        x_residual = x - coeffs * new_direction
+        x_residual_var = (x_residual ** 2).mean()
+        # `variance_proportion` is the proportion of the variance accounted for
+        # by the top eigen-direction.
+        variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+
+        ans_direction = direction + new_direction  # ensure nonzero even if x == 0
+        ans_direction = ans_direction / ans_direction.norm()
+
+        if random.random() < 0.001:
+            logging.info(f"variance_proportion = {variance_proportion.item()}")
+
+        # Caution: this causes a CUDA sync, which is not ideal.
+        if variance_proportion >= max_variance_proportion:
+            ctx.channel_dim = channel_dim
+            ctx.subtract_mean = subtract_mean
+            ctx.grad_scale = grad_scale
+            ctx.save_for_backward(orig_x.detach(),
+                                  coeffs.detach(),
+                                  new_direction.detach())
+
+        return orig_x, ans_direction
+
+    @staticmethod
+    def backward(ctx, x_grad, *args):
+        # the *args is all the other derivs, which should be None or zero.
+        if not hasattr(ctx, 'channel_dim'):
+            # the top eig's proportion of the variance was below the threshold.
+            return x_grad, None, None, None, None, None, None
+        with torch.enable_grad():
+            (x_orig, coeffs, new_direction) = ctx.saved_tensors
+            x_orig.requires_grad = True
+            num_channels = x_orig.shape[ctx.channel_dim]
+            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+            new_direction.requires_grad = False
+            if ctx.subtract_mean:
+                x = x - x.mean(dim=0)
+            x_var = (x ** 2).mean()
+            x_residual = x - coeffs * new_direction
+            x_residual_var = (x_residual ** 2).mean()
+            # `variance_proportion` is the proportion of the variance accounted for
+            # by the top eigen-direction.  This is to be minimized.
+            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+            variance_proportion.backward()
+        x_orig_grad = x_orig.grad
+        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
+        return x_grad + x_extra_grad.detach(), None, None, None, None, None, None
+
+
 class BasicNorm(torch.nn.Module):
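For context on what find_direction_coeffs is doing: each call is effectively one power-iteration step on x^T x, so repeatedly feeding the returned direction back in converges to the top eigenvector, and the variance_proportion computed from it converges to the top eigenvalue's share of the total variance. A standalone sketch (illustrative, not repository code) that re-derives this and checks it against torch.linalg.eigh:

```python
import torch

torch.manual_seed(0)
num_frames, num_channels = 200, 16
x = torch.randn(num_frames, num_channels)
x = x - x.mean(dim=0)                # the subtract_mean=True case

direction = torch.randn(num_channels)
for _ in range(50):                  # the autograd function does one such step per call
    # coefficients of `direction` in each frame (up to a positive constant factor)
    coeffs = (x * direction).sum(dim=1, keepdim=True) + 1.0e-10
    # least-squares re-estimate of the direction given those coefficients
    direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
direction = direction / direction.norm()

# reference: top eigenvector of x^T x (eigh returns eigenvalues in ascending order)
evals, evecs = torch.linalg.eigh(x.t() @ x)
top = evecs[:, -1]
print((direction * top).sum().abs())          # cosine close to 1.0

# proportion of the variance explained by that direction, as in the diff
coeffs = (x * direction).sum(dim=1, keepdim=True)
residual = x - coeffs * direction
print(1.0 - (residual ** 2).mean() / (x ** 2).mean())   # ~ evals[-1] / evals.sum()
```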
@@ -236,6 +338,7 @@ class ActivationBalancer(torch.nn.Module):


    Args:
+      num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
          -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
@@ -252,29 +355,60 @@ class ActivationBalancer(torch.nn.Module):
       max_abs: the maximum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
+      max_var_per_eig: the maximum proportion of the variance of the
+          features/channels, after mean subtraction, that can come from
+          any given eigenvalue.
    """

    def __init__(
        self,
+       num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.01,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
+       max_var_per_eig: float = 0.0,
    ):
        super(ActivationBalancer, self).__init__()
+       self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
+       assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+       self.max_var_per_eig = max_var_per_eig
+       if max_var_per_eig > 0.0:
+           with torch.no_grad():
+               # arbitrary.. would use randn() but want to leave the rest of the model's
+               # random parameters unchanged for comparison
+               direction = torch.arange(num_channels).to(torch.float)
+               direction = direction / direction.norm()
+               self.register_buffer('max_eig_direction', direction)
+       else:
+           self.max_eig_direction = None

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting():
            return x

+       if self.max_var_per_eig > 0:
+           max_eig_prob = 0.25
+           with torch.cuda.amp.autocast(enabled=False):
+               x, new_direction = MaxEigLimiterFunction.apply(
+                   x, self.max_eig_direction,
+                   self.channel_dim,
+                   max_eig_prob,
+                   True,  # subtract_mean
+                   self.max_var_per_eig,
+                   self.max_factor / max_eig_prob,
+               )
+           self.max_eig_direction[:] = new_direction.detach()
+
        return ActivationBalancerFunction.apply(
            x,
            self.channel_dim,
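A hedged usage sketch of the new ActivationBalancer signature (assumes the updated scaling.py is importable; the channel count and max_var_per_eig value are illustrative):

```python
import torch
from scaling import ActivationBalancer  # assumed importable

num_channels = 256
balancer = ActivationBalancer(num_channels,         # now a required first argument
                              channel_dim=-1,
                              max_abs=6.0,
                              max_var_per_eig=0.2)   # 0.0 (the default) disables the limiter

x = torch.randn(20, 4, num_channels, requires_grad=True)
y = balancer(x)                 # identity in the forward pass
y.sum().backward()              # derivatives are modified in the backward pass; the
                                # max-eig part only fires with prob 0.25, and only changes
                                # gradients when one direction exceeds 20% of the variance
print(torch.equal(y, x), x.grad.shape)
```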
@@ -326,6 +460,35 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)


+def _test_max_eig_limiter():
+
+    for proportion in [0.1, 0.5, 10.0]:
+        logging.info(f"proportion = {proportion}")
+        x = torch.randn(100, 128)
+        direction = torch.randn(128)
+        coeffs = torch.randn(100, 1)
+        x += proportion * direction * coeffs
+
+        x.requires_grad = True
+
+        y, new_direction = MaxEigLimiterFunction.apply(x, direction,
+                                                       1,     # channel_dim
+                                                       1.0,   # prob
+                                                       True,  # subtract_mean
+                                                       0.5,   # max_variance_proportion
+                                                       0.1,   # grad_scale
+                                                       )
+
+        cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
+        logging.info(f"Direction cosine = {cosine}")
+
+        y_grad = torch.randn_like(x)
+        y.backward(gradient=y_grad)
+
+        if proportion < 0.2:
+            assert torch.allclose(x.grad, y_grad)
+        elif proportion > 1.0:
+            assert not torch.allclose(x.grad, y_grad)
+
+
@@ -336,6 +499,7 @@ def _test_activation_balancer_sign():
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
+        probs.numel(),
         channel_dim=0,
         min_positive=0.05,
         max_positive=0.98,
@@ -361,6 +525,7 @@ def _test_activation_balancer_magnitude():
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
+        magnitudes.numel(),
         channel_dim=0,
         min_positive=0.0,
         max_positive=1.0,
@@ -402,10 +567,12 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_max_eig_limiter()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()