Remove scale_speed, make swish deriv more efficient.

This commit is contained in:
Daniel Povey 2022-03-18 16:31:25 +08:00
parent cbe6b175d1
commit 6769087d70
3 changed files with 181 additions and 185 deletions
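The main change is removing the `scale_speed` multiplier from all learned log-scale parameters (the effective initial scale is unchanged, since the stored value is no longer pre-divided by `scale_speed`), plus a closed-form backward for the double-swish activation. A minimal sketch of the parametrization change, with illustrative values not taken from the diff:

import torch

# Illustrative values only; `scale_speed` was 5.0 in the removed code.
scale_speed = 5.0
initial_scale = 2.0

# Before: the stored parameter was log(initial_scale) / scale_speed, and it
# was multiplied back by scale_speed before exponentiating.
old_param = torch.tensor(initial_scale).log() / scale_speed
old_effective = (old_param * scale_speed).exp()

# After: the stored parameter is simply log(initial_scale).
new_param = torch.tensor(initial_scale).log()
new_effective = new_param.exp()

assert torch.allclose(old_effective, new_effective)  # same effective scale

Only the learning dynamics differ: gradients with respect to the stored parameter are no longer amplified by `scale_speed`.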

View File

@@ -410,7 +410,6 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
- scale_speed: float = 5.0
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
@@ -430,16 +429,15 @@ class RelPositionMultiheadAttention(nn.Module):
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
- self.scale_speed = scale_speed
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
- return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp()
+ return self.pos_bias_u * self.pos_bias_u_scale.exp()
def _pos_bias_v(self):
- return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
+ return self.pos_bias_v * self.pos_bias_v_scale.exp()
def _reset_parameters(self) -> None:
nn.init.normal_(self.pos_bias_u, std=0.05)

View File

@@ -19,7 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
- from scaling import ScaledConv1d, ScaledLinear
+ from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
class Decoder(nn.Module):
@@ -103,139 +103,3 @@ class Decoder(nn.Module):
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = self.output_linear(F.relu(embedding_out))
return embedding_out
class ScaledEmbedding(nn.Module):
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
See Notes for more details regarding sparse gradients.
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
.. note::
Keep in mind that only a limited number of optimizers support
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
.. note::
With :attr:`padding_idx` set, the embedding vector at
:attr:`padding_idx` is initialized to all zeros. However, note that this
vector can be modified afterwards, e.g., using a customized
initialization method, and thus changing the vector used to pad the
output. The gradient for this vector from :class:`~torch.nn.Embedding`
is always zero.
Examples::
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.1535, -2.0309, 0.9315],
[ 0.0000, 0.0000, 0.0000],
[-0.1655, 0.9897, 0.0635]]])
"""
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
'scale_grad_by_freq', 'sparse']
num_embeddings: int
embedding_dim: int
padding_idx: int
scale_grad_by_freq: bool
weight: Tensor
sparse: bool
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False,
scale_speed: float = 5.0) -> None:
super(ScaledEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.scale_grad_by_freq = scale_grad_by_freq
self.scale_speed = scale_speed
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
self.sparse = sparse
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.05)
nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
scale = (self.scale * self.scale_speed).exp()
if input.numel() < self.num_embeddings:
return F.embedding(
input, self.weight, self.padding_idx,
None, 2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq, self.sparse) * scale
else:
return F.embedding(
input, self.weight * scale, self.padding_idx,
None, 2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq, self.sparse)
def extra_repr(self) -> str:
s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
if self.padding_idx is not None:
s += ', padding_idx={padding_idx}'
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
if self.sparse is not False:
s += ', sparse=True'
return s.format(**self.__dict__)

View File

@@ -18,7 +18,7 @@
import torch
import torch.nn as nn
from torch import Tensor
- from typing import Tuple
+ from typing import Tuple, Optional
@@ -94,31 +94,25 @@ class BasicNorm(torch.nn.Module):
to indicate the connection with conventional LayerNorm.
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
- eps_speed: a constant that determines how fast "eps" learns;
- with Adam and variants, this should probably be >= 1,
- e.g. 5.0. For SGD and variants, probably a value less than one,
- like 0.1, would be suitable, to prevent instability.
"""
def __init__(self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25,
- learn_eps: bool = True,
- eps_speed: float = 5.0):
+ learn_eps: bool = True) -> None:
super(BasicNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
- self.eps_speed = eps_speed
if learn_eps:
- self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
else:
- self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())
+ self.register_buffer('eps', torch.tensor(eps).log().detach())
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
- (self.eps * self.eps_speed).exp()) ** -0.5
+ self.eps.exp()) ** -0.5
return x * scales
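For context, a rough standalone sketch of what the updated BasicNorm.forward computes, with illustrative shapes and channel_dim=-1 (not code from the repository):

import torch

x = torch.randn(2, 4)                   # (frames, channels), illustrative
eps = torch.tensor(0.25).log()          # now stored directly in log-space
scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
y = x * scales                          # per-frame mean square pulled toward 1
print(torch.mean(y ** 2, dim=-1))       # somewhat below 1 because of eps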
@@ -128,16 +122,13 @@ class ScaledLinear(nn.Linear):
"""
A modified version of nn.Linear where the parameters are scaled before
use, via:
- weight = self.weight * (self.weight_scale * self.scale_speed).exp()
- bias = self.bias * (self.bias_scale * self.scale_speed).exp()
+ weight = self.weight * self.weight_scale.exp()
+ bias = self.bias * self.bias_scale.exp()
Args:
Accepts the standard args and kwargs that nn.Linear accepts
e.g. in_features, out_features, bias=False.
- scale_speed: a factor that affects how fast the weight_scale
- and bias_scale learn; this value is suitable for Adam-type
- optimizers.
initial_scale: you can override this if you want to increase
or decrease the initial magnitude of the module's output
(affects the initialization of weight_scale and bias_scale).
@@ -149,13 +140,11 @@ class ScaledLinear(nn.Linear):
may be larger than optimal.
"""
def __init__(self, *args,
- scale_speed: float = 5.0,
initial_scale: float = 1.0,
**kwargs):
super(ScaledLinear, self).__init__(*args, **kwargs)
- initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+ initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
- self.scale_speed = scale_speed
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
@@ -172,14 +161,14 @@ class ScaledLinear(nn.Linear):
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
- self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
+ self.weight_scale += torch.tensor(scale / std).log()
def get_weight(self):
- return self.weight * (self.weight_scale * self.scale_speed).exp()
+ return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
- self.bias * (self.bias_scale * self.scale_speed).exp())
+ self.bias * self.bias_scale.exp())
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(input, self.get_weight(),
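A hedged sketch of the scaled-parameter idea after this change: the stored weight and bias are multiplied by the exponential of learned log-scales before being used in an ordinary linear layer. Shapes and values below are illustrative, not taken from the library:

import torch
import torch.nn.functional as F

weight = torch.randn(8, 16)          # stored parameters
bias = torch.randn(8)
weight_scale = torch.tensor(0.0)     # learned log-scales; exp(0) == 1 at init
bias_scale = torch.tensor(0.0)

x = torch.randn(3, 16)
y = F.linear(x, weight * weight_scale.exp(), bias * bias_scale.exp())
assert y.shape == (3, 8)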
@@ -187,11 +176,10 @@ class ScaledLinear(nn.Linear):
class ScaledConv1d(nn.Conv1d):
- def __init__(self, *args, scale_speed = 5.0,
+ def __init__(self, *args,
initial_scale=1.0, **kwargs):
super(ScaledConv1d, self).__init__(*args, **kwargs)
- self.scale_speed = scale_speed
- initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+ initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
@@ -208,15 +196,15 @@ class ScaledConv1d(nn.Conv1d):
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
- self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
+ self.weight_scale += torch.tensor(scale / std).log()
def get_weight(self):
- return self.weight * (self.weight_scale * self.scale_speed).exp()
+ return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
- self.bias * (self.bias_scale * self.scale_speed).exp())
+ self.bias * self.bias_scale.exp())
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
@@ -230,10 +218,9 @@ class ScaledConv1d(nn.Conv1d):
class ScaledConv2d(nn.Conv2d):
- def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
+ def __init__(self, *args, initial_scale=1.0, **kwargs):
super(ScaledConv2d, self).__init__(*args, **kwargs)
- self.scale_speed = scale_speed
- initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+ initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
@@ -250,15 +237,15 @@ class ScaledConv2d(nn.Conv2d):
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
- self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed)
+ self.weight_scale += torch.tensor(scale / std).log()
def get_weight(self):
- return self.weight * (self.weight_scale * self.scale_speed).exp()
+ return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
- self.bias * (self.bias_scale * self.scale_speed).exp())
+ self.bias * self.bias_scale.exp())
def _conv_forward(self, input, weight):
F = torch.nn.functional
@@ -323,6 +310,16 @@ class ActivationBalancer(torch.nn.Module):
self.max_factor, self.min_abs,
self.max_abs)
# deriv of double_swish:
# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally
# motivated by its similarity to swish(swish(x)),
# where swish(x) = x * sigmoid(x)].
# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
# Now, s'(x) = s(x) * (1-s(x)).
# double_swish'(x) = x * s'(x) + s(x)
#                  = x * s(x) * (1-s(x)) + s(x)
#                  = double_swish(x) * (1-s(x)) + s(x)
def _double_swish(x: Tensor) -> Tensor:
# double-swish, implemented/approximated as offset-swish
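The derivation above can be checked numerically; the sketch below (illustrative values, not part of the commit) compares the closed-form derivative used in the new backward against autograd:

import torch

x = torch.linspace(-3.0, 3.0, 7, dtype=torch.double, requires_grad=True)
y = x * torch.sigmoid(x - 1.0)                  # double_swish(x)
(autograd_deriv,) = torch.autograd.grad(y.sum(), x)

s = torch.sigmoid(x - 1.0).detach()
closed_form = (x.detach() * s) * (1 - s) + s    # double_swish(x)*(1-s) + s
assert torch.allclose(autograd_deriv, closed_form)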
@@ -331,18 +328,16 @@ def _double_swish(x: Tensor) -> Tensor:
class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
- ctx.save_for_backward(x.detach())
- return _double_swish(x)
+ x = x.detach()
+ s = torch.sigmoid(x - 1.0)
+ y = x * s
+ ctx.save_for_backward(s, y)
+ return y
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
- # TODO: can make this more efficient.
- x, = ctx.saved_tensors
- x.requires_grad = True
- with torch.enable_grad():
- y = _double_swish(x)
- y.backward(gradient=y_grad)
- return x.grad
+ s, y = ctx.saved_tensors
+ return (y * (1-s) + s) * y_grad
class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
@@ -353,6 +348,140 @@ class DoubleSwish(torch.nn.Module):
class ScaledEmbedding(nn.Module):
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
See Notes for more details regarding sparse gradients.
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
.. note::
Keep in mind that only a limited number of optimizers support
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
.. note::
With :attr:`padding_idx` set, the embedding vector at
:attr:`padding_idx` is initialized to all zeros. However, note that this
vector can be modified afterwards, e.g., using a customized
initialization method, and thus changing the vector used to pad the
output. The gradient for this vector from :class:`~torch.nn.Embedding`
is always zero.
Examples::
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.1535, -2.0309, 0.9315],
[ 0.0000, 0.0000, 0.0000],
[-0.1655, 0.9897, 0.0635]]])
"""
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
'scale_grad_by_freq', 'sparse']
num_embeddings: int
embedding_dim: int
padding_idx: int
scale_grad_by_freq: bool
weight: Tensor
sparse: bool
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> None:
super(ScaledEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.scale_grad_by_freq = scale_grad_by_freq
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
self.sparse = sparse
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.05)
nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log())
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
scale = self.scale.exp()
if input.numel() < self.num_embeddings:
return F.embedding(
input, self.weight, self.padding_idx,
None, 2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq, self.sparse) * scale
else:
return F.embedding(
input, self.weight * scale, self.padding_idx,
None, 2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq, self.sparse)
def extra_repr(self) -> str:
s = '{num_embeddings}, {embedding_dim}, scale={scale}'
if self.padding_idx is not None:
s += ', padding_idx={padding_idx}'
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
if self.sparse is not False:
s += ', sparse=True'
return s.format(**self.__dict__)
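A hedged sketch of why the forward above branches on input.numel(): scaling the looked-up rows and scaling the whole weight matrix give the same result, so the code applies the scale to whichever tensor is smaller. Values below are illustrative:

import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)                 # (num_embeddings, embedding_dim)
scale = torch.tensor(0.5).exp()             # exp of the learned log-scale
idx = torch.tensor([[1, 2, 4, 5]])

scaled_output = F.embedding(idx, weight) * scale    # few indices: scale the output
scaled_weight = F.embedding(idx, weight * scale)    # many indices: scale the weight
assert torch.allclose(scaled_output, scaled_weight)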
def _test_activation_balancer_sign():
channel_dim = 0
probs = torch.arange(0, 1, 0.01)
@@ -409,10 +538,15 @@ def _test_basic_norm():
assert y_rms > 0.5 * x_rms
def _test_double_swish_deriv():
x = torch.randn(10, 12, dtype=torch.double) * 0.5
x.requires_grad = True
m = DoubleSwish()
torch.autograd.gradcheck(m, x)
if __name__ == '__main__':
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()