Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 744dca1c9b: Merge branch 'scaled_adam_exp724' into scaled_adam_exp726
@@ -463,120 +463,6 @@ class BasicNorm(torch.nn.Module):
         return x * scales
 
 
-class LinearWithAuxLossFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, weight: Tensor,
-                aux_grad_scale: float) -> Tensor:
-        """
-        Returns matmul(x, weight.t()).
-        In the backward pass it will include an auxiliary loss based on predicting x from
-        matmul(y, weight).
-        """
-        if torch.is_autocast_enabled():
-            x = x.to(torch.float16)
-        ctx.save_for_backward(x, weight)
-        ctx.aux_grad_scale = aux_grad_scale
-        return torch.matmul(x, weight.t())
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
-        x, weight = ctx.saved_tensors
-
-        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
-        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
-                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
-
-        with torch.cuda.amp.autocast(enabled=False):
-            with torch.enable_grad():
-                x = x.to(weight.dtype)
-                x, weight = x.detach(), weight.detach()
-                weight.requires_grad = True
-                # recompute y as we need the gradient; this is easier to implement than
-                # saving y in the context.
-                y = torch.matmul(x, weight.t())
-                z = torch.matmul(y, weight)
-                # subtract mean
-                dims_to_mean = tuple(range(x.ndim-1))
-                x = x - x.mean(dim=dims_to_mean)
-                z = z - z.mean(dim=dims_to_mean)
-                # compute optimal scale on z
-                with torch.no_grad():
-                    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
-                diff = x - alpha * z
-                # meansq is the loss function.
-                meansq = (diff ** 2).mean()
-                meansq.backward()
-                weight_aux_grad = weight.grad
-
-        with torch.cuda.amp.autocast(enabled=False):
-            weight_grad_norm = weight_grad.to(torch.float32).norm()
-            aux_grad_norm = weight_aux_grad.norm()
-            weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
-            weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
-
-        return x_grad, weight_grad, None
-
-
-class LinearWithAuxLoss(nn.Module):
-    """
-    A linear layer with an auxiliary loss that you can put on a schedule, that
-    encourages it to correspond to the largest-variance directions of the
-    input features.
-
-    Suppose the input is x, and this layer computes:
-          y = M x
-    (the bias is applied separately), then we define:
-          z = exp(alpha) * M^T y
-    where alpha is learnable; and the auxiliary loss will be:
-          aux_loss = normalize_mean(z - x)^2.
-    (normalize_mean refers to subtracting the average value per channel,
-    over the minibatch).
-    In the backward pass we compute the derivative of the auxiliary loss
-    and add it to the weight and bias grads, with a scale chosen such
-    that the extra grad's norm equals `aux_grad_scales` times the norm
-    of the existing grad.
-    """
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 bias: bool = True,
-                 aux_grad_scale: Optional[FloatLike] = None,
-                 prob: FloatLike = 0.25,
-                 initial_scale: float = 1.0,
-                 ):
-        super().__init__()
-        if aux_grad_scale is None:
-            aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
-                                            (2000.0, 0.01), (8000.0, 0.0))
-
-        self.aux_grad_scale = aux_grad_scale
-        self.prob = prob
-
-        self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
-                                   * (in_channels ** -0.5) * initial_scale)
-        if bias:
-            self.bias = nn.Parameter(torch.randn(out_channels) *
-                                     0.01 * initial_scale)
-        else:
-            self.register_parameter('bias', None)
-
-    def forward(self,
-                x: Tensor):
-        aux_grad_scale = float(self.aux_grad_scale)
-        if (not self.training or torch.jit.is_scripting() or
-                aux_grad_scale == 0.0 or random.random() > float(self.prob)):
-            return torch.nn.functional.linear(x, self.weight, self.bias)
-        else:
-            ans = LinearWithAuxLossFunction.apply(x, self.weight,
-                                                  aux_grad_scale)
-            if self.bias is not None:
-                ans += self.bias
-            return ans
-
-
 def ScaledLinear(*args,
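For reference (not part of the commit): a minimal standalone sketch of the auxiliary loss that the removed LinearWithAuxLossFunction computed in its backward pass. Tensor shapes and the 1.0e-20 guard follow the removed code; the function name aux_loss_sketch is hypothetical.

import torch

def aux_loss_sketch(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: (batch..., in_channels); weight: (out_channels, in_channels)
    y = torch.matmul(x, weight.t())       # the layer's normal output
    z = torch.matmul(y, weight)           # project back into the input space
    dims_to_mean = tuple(range(x.ndim - 1))
    x = x - x.mean(dim=dims_to_mean)      # subtract per-channel mean
    z = z - z.mean(dim=dims_to_mean)
    with torch.no_grad():
        # closed-form least-squares scale for predicting x from z
        alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
    diff = x - alpha * z
    # the removed code blended the gradient of this quantity into weight.grad,
    # rescaled so its norm was aux_grad_scale times the norm of the main grad
    return (diff ** 2).mean()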
@@ -711,12 +597,14 @@ class ActivationBalancer(torch.nn.Module):
             scale_gain_factor: FloatLike = 0.04,
             min_abs: FloatLike = 0.2,
             max_abs: FloatLike = 100.0,
-            min_prob: FloatLike = 0.1,
+            prob: Optional[FloatLike] = None,
     ):
         super(ActivationBalancer, self).__init__()
-        # CAUTION: this code expects self.batch_count to be overwritten in the main training
-        # loop.
-        self.batch_count = 0
+        if prob is None:
+            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
+        self.prob = prob
 
         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -726,7 +614,6 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
-        self.min_prob = min_prob
         self.sign_gain_factor = sign_gain_factor
         self.scale_gain_factor = scale_gain_factor
 
@@ -738,9 +625,7 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting() or not x.requires_grad:
             return _no_op(x)
 
-        # the prob of doing some work exponentially decreases from 0.5 till it hits
-        # a floor at min_prob (==0.1, by default)
-        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = float(self.prob)
 
         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels
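The hunks above swap the old batch_count-driven probability for a schedule-driven one. As a rough standalone illustration (piecewise_linear is a hypothetical stand-in for icefall's ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4); old_prob mirrors the removed formula):

def piecewise_linear(batch_count: float) -> float:
    # stand-in for a ScheduledFloat: linear interpolation between schedule points
    (x0, y0), (x1, y1) = (0.0, 0.4), (8000.0, 0.1)
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

def old_prob(batch_count: float, min_prob: float = 0.1) -> float:
    # the exponential decay with a floor that this commit removes
    return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))

for step in (0, 2000, 4000, 8000, 16000):
    print(step, round(old_prob(step), 3), round(piecewise_linear(step), 3))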
@@ -36,7 +36,6 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
-    LinearWithAuxLoss,
     Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
@@ -355,16 +354,11 @@ class Zipformer(EncoderInterface):
 
 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x),
-                          (12000.0, ratio * x),
+                          (20000.0, ratio * x),
                           default=x)
 
-def _aux_grad_scale() -> float:
-    return 0.2
-def _aux_grad_prob_out() -> ScheduledFloat:
-    return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
-def _aux_grad_prob_in() -> ScheduledFloat:
-    return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.0))
-    #return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
+def _balancer_schedule(min_prob: float):
+    return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
 
 
 
@@ -398,8 +392,8 @@ class ZipformerEncoderLayer(nn.Module):
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
-            dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
-            nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
+            attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
+            conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
             bypass_max: FloatLike = 1.0,
@@ -410,10 +404,10 @@ class ZipformerEncoderLayer(nn.Module):
         # probability of skipping the entire layer.
         self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward).
-        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
-        # an additional skip probability that applies to NoninAttentionModule to stop it from
+        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+        # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
-        self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate)
+        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
 
         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
@@ -507,7 +501,7 @@ class ZipformerEncoderLayer(nn.Module):
         src_orig = src
 
         # dropout rate for non-feedforward submodules
-        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
+        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
         if self.self_attn_weights is not None:
@@ -528,7 +522,7 @@ class ZipformerEncoderLayer(nn.Module):
                 # skip the layer
                 return src, attn_weights
 
-        use_self_attn = (random.random() >= dynamic_skip_rate)
+        use_self_attn = (random.random() >= attention_skip_rate)
         if use_self_attn:
             selected_attn_weights = attn_weights[head_offset:head_offset+2]
             if random.random() < float(self.const_attention_rate):
@@ -541,7 +535,7 @@ class ZipformerEncoderLayer(nn.Module):
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                 selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
-        if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)):
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
                                                      selected_attn_weights[0:1])
 
@@ -555,7 +549,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn(
                 src, attn_weights)
 
-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
 
         src = src + self.feed_forward2(src)
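The forward-pass edits above all follow the same stochastic module-skipping pattern. A minimal, hypothetical sketch of that pattern in isolation (maybe_apply is not a function in the codebase; the real code inlines this logic per submodule):

import random
import torch
import torch.nn as nn

def maybe_apply(module: nn.Module, src: torch.Tensor, skip_rate: float, training: bool) -> torch.Tensor:
    # with probability `skip_rate` (training only), skip the submodule entirely;
    # otherwise add its output residually, as the forward code above does
    rate = skip_rate if training else 0.0
    if random.random() >= rate:
        return src + module(src)
    return src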
@@ -750,6 +744,7 @@ class AttentionDownsample(torch.nn.Module):
         """
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+        self.bias = nn.Parameter(torch.zeros(downsample))
 
         self.name = None # will be set from training code
         self.dropout = copy.deepcopy(dropout)
@@ -783,8 +778,9 @@ class AttentionDownsample(torch.nn.Module):
             assert src.shape[0] == d_seq_len * ds
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-        # scores: (d_seq_len, downsample, batch_size)
+        # scores: (d_seq_len, downsample, batch_size, 1)
        scores = (src * self.query).sum(dim=-1, keepdim=True)
+        scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
 
         scores = penalize_abs_values_gt(scores,
                                         limit=20.0,
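A quick shape walk-through (illustration only, not from the repo) for the per-frame bias added to AttentionDownsample; the final softmax/weighted-sum lines are an assumption about how such scores are typically used, since that part of forward() is not in this hunk:

import torch

d_seq_len, ds, batch_size, in_channels = 5, 2, 3, 8     # arbitrary example sizes
src = torch.randn(d_seq_len, ds, batch_size, in_channels)
query = torch.randn(in_channels) * (in_channels ** -0.5)
bias = torch.zeros(ds)                                   # the newly added parameter

scores = (src * query).sum(dim=-1, keepdim=True)         # (d_seq_len, ds, batch_size, 1)
scores = scores + bias.unsqueeze(-1).unsqueeze(-1)       # (ds, 1, 1) broadcasts over the other dims

# assumed continuation: normalize over the ds frames being merged, then weighted-sum them
weights = scores.softmax(dim=1)
out = (src * weights).sum(dim=1)                         # (d_seq_len, batch_size, in_channels)
print(scores.shape, out.shape)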
@@ -820,7 +816,7 @@ class SimpleUpsample(torch.nn.Module):
                  num_channels: int,
                  upsample: int):
         super(SimpleUpsample, self).__init__()
-        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+        self.upsample = upsample
 
     def forward(self,
                 src: Tensor) -> Tensor:
@@ -829,10 +825,9 @@ class SimpleUpsample(torch.nn.Module):
         Returns a tensor of shape
            ( (seq_len*upsample), batch_size, num_channels)
         """
-        upsample = self.bias.shape[0]
+        upsample = self.upsample
         (seq_len, batch_size, num_channels) = src.shape
         src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
-        src = src + self.bias.unsqueeze(1)
         src = src.reshape(seq_len * upsample, batch_size, num_channels)
         return src
 
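After this change SimpleUpsample has no parameters; functionally it reduces to the following sketch (the free-function form is illustrative only, the operations match the diff):

import torch

def simple_upsample(src: torch.Tensor, upsample: int) -> torch.Tensor:
    # src: (seq_len, batch_size, num_channels) -> (seq_len * upsample, batch_size, num_channels)
    seq_len, batch_size, num_channels = src.shape
    # repeat each frame `upsample` times; reshape copies the expanded view
    src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
    return src.reshape(seq_len * upsample, batch_size, num_channels)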
@@ -1296,15 +1291,11 @@ class AttentionSqueeze(nn.Module):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim
 
-        self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
-                                         bias=False,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
+        self.in_proj = nn.Linear(embed_dim, hidden_dim,
+                                 bias=False)
 
-        self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
-                                                    bottleneck_dim,
-                                                    aux_grad_scale=_aux_grad_scale(),
-                                                    prob=_aux_grad_prob_in())
+        self.to_bottleneck_proj = nn.Linear(embed_dim,
+                                            bottleneck_dim)
 
         # bottleneck_balancer is before the actiation. Mostly, for well-trained
         # instances of this module, the mean absolute values per channel are in
@@ -1315,7 +1306,6 @@ class AttentionSqueeze(nn.Module):
             min_positive=0.2, max_positive=0.8,
             min_abs=0.05,
             max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0),
-            min_prob=0.1,
         )
         self.bottleneck_activation = TanSwish() # in bottleneck
         self.activation = Identity() # for diagnostics
@@ -1329,13 +1319,13 @@ class AttentionSqueeze(nn.Module):
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.05,
+            prob=_balancer_schedule(0.05),
         )
         self.activation_balancer = ActivationBalancer(
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.05,
+            prob=_balancer_schedule(0.05),
         )
         self.activation_whiten = Whiten(num_groups=1,
                                         whitening_limit=_whitening_schedule(4.0, ratio=3.0),
@@ -1343,17 +1333,16 @@ class AttentionSqueeze(nn.Module):
                                         grad_scale=0.01)
 
 
-        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
+        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)
 
-        self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(),
-                                          prob=_aux_grad_prob_out(),
-                                          bias=False, initial_scale=0.05)
+        self.out_proj = ScaledLinear(hidden_dim, embed_dim,
+                                     bias=False, initial_scale=0.05)
 
         self.out_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
+            min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
+            prob=0.05,  # out of concern for memory usage
         )
 
 
@@ -1404,21 +1393,19 @@ class FeedforwardModule(nn.Module):
                  feedforward_dim: int,
                  dropout: FloatLike):
         super(FeedforwardModule, self).__init__()
-        self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
+        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
 
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1,
                                                   min_positive=0.3,
                                                   max_positive=1.0,
                                                   min_abs=0.75,
-                                                  max_abs=5.0,
-                                                  min_prob=0.25)
+                                                  max_abs=5.0)
         self.activation = SwooshL()
         self.dropout = Dropout2(dropout)
-        self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
-                                          initial_scale=0.01,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
+        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
+                                     initial_scale=0.01)
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
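For orientation, a hedged sketch of the FeedforwardModule data path after this commit, using plain-torch stand-ins (nn.SiLU for SwooshL, nn.Dropout for Dropout2, nn.Linear for ScaledLinear); the operation order is an assumption based on the constructor above, and the ActivationBalancer/Whiten modules are omitted since they mainly shape gradients during training:

import torch
import torch.nn as nn

class FeedforwardSketch(nn.Module):
    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: float = 0.1):
        super().__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)   # plain Linear after this commit
        self.activation = nn.SiLU()                            # stand-in for SwooshL
        self.dropout = nn.Dropout(dropout)                     # stand-in for Dropout2
        self.out_proj = nn.Linear(feedforward_dim, embed_dim)  # stand-in for ScaledLinear(initial_scale=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(self.dropout(self.activation(self.in_proj(x))))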
@@ -1488,11 +1475,11 @@ class NonlinAttentionModule(nn.Module):
         self.balancer2 = ActivationBalancer(
             channels, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
+            min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
+            prob=0.05,  # out of concern for memory usage
         )
 
 
 
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor,
@@ -1560,9 +1547,8 @@ class ConvolutionModule(nn.Module):
             bottleneck_dim = channels
 
 
-        self.in_proj = LinearWithAuxLoss(
+        self.in_proj = nn.Linear(
             channels, 2 * bottleneck_dim,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
         )
 
 
@@ -1618,9 +1604,8 @@ class ConvolutionModule(nn.Module):
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)
 
-        self.out_proj = LinearWithAuxLoss(
+        self.out_proj = ScaledLinear(
             bottleneck_dim, channels,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
             initial_scale=0.05,
         )
 
@@ -1834,8 +1819,7 @@ class Conv2dSubsampling(nn.Module):
         self.scale_max = 1.0
         self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
 
-        self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
-                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
+        self.out = nn.Linear(out_height * layer3_channels, out_channels)
 
         self.dropout = Dropout2(dropout)
 