check some files

luomingshuang 2022-04-11 20:24:46 +08:00
parent 34aad74a2c
commit 05fd40ba68
6 changed files with 342 additions and 194 deletions


@ -7,6 +7,9 @@ per-file-ignores =
egs/librispeech/ASR/*/conformer.py: E501,
egs/aishell/ASR/*/conformer.py: E501,
egs/tedlium3/ASR/*/conformer.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,
# invalid escape sequence (caused by TeX formulas), W605
icefall/utils.py: E501, W605
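
W605 ("invalid escape sequence") fires on backslashes that are common in TeX-style docstrings; a minimal illustration of the kind of line that triggers it (a sketch, not taken from icefall/utils.py):

def tex_docstring_example():
    """Compute softmax(q k^T / \sqrt{d}) v.

    The \s escape above is not a valid Python escape sequence, so flake8
    reports W605 here unless the docstring is made a raw string or, as done
    for icefall/utils.py, the warning is suppressed for the whole file.
    """
    pass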


@ -16,13 +16,20 @@
# limitations under the License.
import copy
from encoder_interface import EncoderInterface
import math
import warnings
from typing import Optional, Tuple, Sequence
from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
from typing import Optional, Tuple
import torch
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask
@ -42,6 +49,7 @@ class Conformer(EncoderInterface):
cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
@ -80,9 +88,8 @@ class Conformer(EncoderInterface):
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -112,8 +119,9 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
warmup=warmup) # (T, N, C)
x = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_final = BasicNorm(d_model)
# try to ensure the output is close to zero-mean (or at least, zero-median).
self.balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
max_abs=6.0)
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: Tensor,
@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module):
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
if self.training:
alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1
alpha = (
warmup_scale
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
else 0.1
)
else:
alpha = 1.0
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
# multi-headed self-attention module
src_att = self.self_attn(
src,
@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1-alpha) * src_orig
src = alpha * src + (1 - alpha) * src_orig
return src
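
A minimal standalone restatement of the warmup blend above, assuming warmup_scale = min(0.1 + warmup, 1.0) and a layer_dropout near 0.075 (both set outside this hunk); illustrative only:

import torch

def blend_with_input(src_orig, src, warmup, layer_dropout=0.075, training=True):
    # Assumed ramp: small layer contribution early in training, full layer once warmup >= 0.9.
    warmup_scale = min(0.1 + warmup, 1.0)
    if training:
        # Usually keep warmup_scale; with probability layer_dropout, nearly bypass the layer.
        alpha = warmup_scale if torch.rand(()).item() <= (1.0 - layer_dropout) else 0.1
    else:
        alpha = 1.0
    # alpha = 1.0 uses the layer output as-is; smaller alpha falls back toward the input.
    return alpha * src + (1 - alpha) * src_orig

x_in = torch.randn(20, 2, 16)                   # (T, N, C)
x_out = x_in + 0.1 * torch.randn_like(x_in)     # stand-in for the layer's output
y = blend_with_input(x_in, x_out, warmup=0.5)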
@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module):
)
self.num_layers = num_layers
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0
warmup: float = 1.0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module):
"""
output = src
num_layers = len(self.layers)
for i, mod in enumerate(self.layers):
output = mod(
output,
@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module):
), "embed_dim must be divisible by num_heads"
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
self.out_proj = ScaledLinear(
embed_dim, embed_dim, bias=True, initial_scale=0.25
)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module):
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
# it will be in a better position to start learning something, i.e. to latch onto
# the correct range.
self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
min_positive=0.05,
max_positive=1.0)
self.deriv_balancer1 = ActivationBalancer(
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
self.depthwise_conv = ScaledConv1d(
channels,
@ -878,9 +885,9 @@ class ConvolutionModule(nn.Module):
bias=bias,
)
self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
min_positive=0.05,
max_positive=1.0)
self.deriv_balancer2 = ActivationBalancer(
channel_dim=1, min_positive=0.05, max_positive=1.0
)
self.activation = DoubleSwish()
@ -891,7 +898,7 @@ class ConvolutionModule(nn.Module):
stride=1,
padding=0,
bias=bias,
initial_scale=0.25
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1)
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
@ -936,11 +942,14 @@ class Conv2dSubsampling(nn.Module):
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128) -> None:
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
) -> None:
"""
Args:
in_channels:
@ -958,34 +967,41 @@ class Conv2dSubsampling(nn.Module):
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1, out_channels=layer1_channels,
kernel_size=3, padding=1,
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=1,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels, out_channels=layer2_channels,
kernel_size=3, stride=2,
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer2_channels, out_channels=layer3_channels,
kernel_size=3, stride=2,
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
self.out = ScaledLinear(
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55)
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module):
return x
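
The input size of self.out follows from the two stride-2, padding-0 convolutions above, each of which maps a length L to (L - 1) // 2; a worked example with an assumed 80-dim input (not from the commit):

in_channels = 80                           # e.g. 80-dim fbank features
layer3_channels = 128
freq = ((in_channels - 1) // 2 - 1) // 2   # 80 -> 39 -> 19
assert freq == 19
# so self.out is ScaledLinear(128 * 19, out_channels); the time axis shrinks
# the same way, which is why the module is described as "to 1/4 length".
T = 100
T_out = ((T - 1) // 2 - 1) // 2            # 100 -> 49 -> 24, roughly T / 4
assert T_out == 24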
if __name__ == '__main__':
if __name__ == "__main__":
feature_dim = 50
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5)
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5,
)


@ -16,15 +16,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScaledLinear
class Joiner(nn.Module):
def __init__(self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
@ -32,8 +34,10 @@ class Joiner(nn.Module):
self.output_linear = ScaledLinear(joiner_dim, vocab_size)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
project_input: bool = True
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
@ -52,7 +56,9 @@ class Joiner(nn.Module):
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
decoder_out
)
else:
logit = encoder_out + decoder_out


@ -37,7 +37,7 @@ class Transducer(nn.Module):
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int
vocab_size: int,
):
"""
Args:
@ -48,11 +48,11 @@ class Transducer(nn.Module):
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim). It should contain
one attribute: `blank_id`.
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
output shape is (N, T, U, vocab_size). Note that its output contains
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
@ -63,8 +63,9 @@ class Transducer(nn.Module):
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
initial_speed=0.5)
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
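
As a shape sketch of the docstring above (illustrative only; this recipe uses k2's pruned loss and never materializes the full lattice), an unpruned joiner would broadcast the two projected streams into an (N, T, U, vocab_size) tensor, with a generic tanh standing in for the joiner nonlinearity:

import torch
import torch.nn as nn

N, T, U, J, V = 2, 50, 10, 512, 500                              # assumed sizes
enc = torch.randn(N, T, J)                                       # projected encoder output
dec = torch.randn(N, U, J)                                       # projected prediction-net output
activations = torch.tanh(enc.unsqueeze(2) + dec.unsqueeze(1))    # (N, T, U, J) by broadcasting
logits = nn.Linear(J, V)(activations)                            # (N, T, U, V), unnormalized
print(logits.shape)                                              # torch.Size([2, 50, 10, 500])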
def forward(
@ -166,15 +167,14 @@ class Transducer(nn.Module):
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned,
project_input=False)
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,

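The "optimization for speed" comment above covers two things: the joiner only runs on the pruned (B, T, prune_range) lattice rather than the full (B, T, U) one, and the projections to joiner_dim are applied once to encoder_out and decoder_out instead of inside the joiner for every pruned (t, u) pair. A rough element-count comparison with assumed sizes:

B, T, U, joiner_dim, vocab_size, prune_range = 8, 500, 100, 512, 500, 5

full_lattice = B * T * U * vocab_size              # what an unpruned joiner would emit
pruned_lattice = B * T * prune_range * vocab_size  # what rnnt_loss_pruned needs
print(full_lattice // pruned_lattice)              # 20x fewer joiner outputs here

# Projecting encoder_out (B, T, joiner_dim) and decoder_out (B, U, joiner_dim)
# before do_rnnt_pruning costs B * (T + U) * joiner_dim projections once, instead
# of projecting inside the joiner for every pruned (t, u) pair.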

@ -15,54 +15,86 @@
# limitations under the License.
import collections
from itertools import repeat
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Optional
def _ntuple(n):
def parse(x):
if isinstance(x, collections.Iterable):
return x
return tuple(repeat(x, n))
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor,
channel_dim: int,
min_positive: float, # e.g. 0.05
max_positive: float, # e.g. 0.95
max_factor: float, # e.g. 0.01
min_abs: float, # e.g. 0.2
max_abs: float, # e.g. 100.0
def forward(
ctx,
x: Tensor,
channel_dim: int,
min_positive: float, # e.g. 0.05
max_positive: float, # e.g. 0.95
max_factor: float, # e.g. 0.01
min_abs: float, # e.g. 0.2
max_abs: float, # e.g. 100.0
) -> Tensor:
if x.requires_grad:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
xgt0 = x > 0
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
if min_positive != 0.0 else 0.0)
factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
if max_positive != 1.0 else 0.0)
proportion_positive = torch.mean(
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
)
factor1 = (
(min_positive - proportion_positive).relu()
* (max_factor / min_positive)
if min_positive != 0.0
else 0.0
)
factor2 = (
(proportion_positive - max_positive).relu()
* (max_factor / (max_positive - 1.0))
if max_positive != 1.0
else 0.0
)
factor = factor1 + factor2
if isinstance(factor, float):
factor = torch.zeros_like(proportion_positive)
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
below_threshold = (mean_abs < min_abs)
above_threshold = (mean_abs > max_abs)
below_threshold = mean_abs < min_abs
above_threshold = mean_abs > max_abs
ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
ctx.save_for_backward(
factor, xgt0, below_threshold, above_threshold
)
ctx.max_factor = max_factor
ctx.sum_dims = sum_dims
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
def backward(
ctx, x_grad: Tensor
) -> Tuple[Tensor, None, None, None, None, None, None]:
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
dtype = x_grad.dtype
scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
(xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
scale_factor = (
(below_threshold.to(dtype) - above_threshold.to(dtype))
* (xgt0.to(dtype) - 0.5)
* (ctx.max_factor * 2.0)
)
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
return x_grad - neg_delta_grad, None, None, None, None, None, None
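
Read as a whole, the backward above nudges each gradient by at most max_factor times its own magnitude, with a sign chosen per element so that gradient descent pushes the channel toward the requested proportion-positive and mean-absolute-value ranges. A small standalone illustration of the scale_factor term (toy values, not the class itself):

import torch

max_factor = 0.01
x_grad = torch.randn(4)
xgt0 = torch.tensor([True, True, False, False])
below_threshold = torch.tensor(True)    # this channel's mean |x| is below min_abs
above_threshold = torch.tensor(False)

scale_factor = (
    (below_threshold.float() - above_threshold.float())
    * (xgt0.float() - 0.5)
    * (max_factor * 2.0)
)
# +max_factor where x > 0 and -max_factor where x < 0: both grow |x| under SGD.
print(scale_factor)                     # tensor([ 0.0100,  0.0100, -0.0100, -0.0100])
new_grad = x_grad - x_grad.abs() * scale_factor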
@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module):
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
"""
def __init__(self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25,
learn_eps: bool = True) -> None:
def __init__(
self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25,
learn_eps: bool = True,
) -> None:
super(BasicNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
if learn_eps:
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
else:
self.register_buffer('eps', torch.tensor(eps).log().detach())
self.register_buffer("eps", torch.tensor(eps).log().detach())
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
self.eps.exp()) ** -0.5
scales = (
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+ self.eps.exp()
) ** -0.5
return x * scales
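
In other words, BasicNorm divides each frame by the root-mean-square of its channels plus a log-parameterized epsilon, with no learnable per-channel scale; a standalone numeric check of the same formula (a sketch, not the module):

import torch

num_channels, eps = 8, 0.25
x = torch.randn(3, num_channels)
scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
y = x * scales
# y's per-frame mean square is mean(x^2) / (mean(x^2) + eps), i.e. a bit below 1 here.
print((y ** 2).mean(dim=-1))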
class ScaledLinear(nn.Linear):
"""
A modified version of nn.Linear where the parameters are scaled before
@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear):
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
"""
def __init__(self, *args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs):
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
):
super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self.register_parameter("bias_scale", None)
self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in nn.Linear
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * self.bias_scale.exp())
return None if self.bias is None else self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(input, self.get_weight(),
self.get_bias())
return torch.nn.functional.linear(
input, self.get_weight(), self.get_bias()
)
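
The effect of initial_scale can be seen by expanding get_weight() and get_bias(): the module computes linear(x, weight * exp(weight_scale), bias * exp(bias_scale)), so scaling both stored parameters by exp(weight_scale) simply scales the output, and with initial_scale = 0.25 the initial outputs are 0.25 times what the same unscaled parameters would give. A minimal check of that decomposition (toy sizes, plain nn.Linear standing in for the stored parameters):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)                        # stand-in for the stored weight and bias
log_scale = torch.tensor(0.25).log()         # how initial_scale is stored (log space)
w = lin.weight * log_scale.exp()
b = lin.bias * log_scale.exp()
x = torch.randn(2, 4)
y_scaled = torch.nn.functional.linear(x, w, b)
assert torch.allclose(y_scaled, 0.25 * lin(x))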
class ScaledConv1d(nn.Conv1d):
# See docs for ScaledLinear
def __init__(self, *args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs):
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d):
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
def get_weight(self):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * self.bias_scale.exp())
return None if self.bias is None else self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
self.get_weight(), self.get_bias(), self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
if self.padding_mode != "zeros":
return F.conv1d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
self.get_weight(),
self.get_bias(),
self.stride,
_single(0),
self.dilation,
self.groups,
)
return F.conv1d(
input,
self.get_weight(),
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
class ScaledConv2d(nn.Conv2d):
# See docs for ScaledLinear
def __init__(self, *args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs):
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
):
super(ScaledConv2d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d):
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
def get_weight(self):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * self.bias_scale.exp())
return None if self.bias is None else self.bias * self.bias_scale.exp()
def _conv_forward(self, input, weight):
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.get_bias(), self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
if self.padding_mode != "zeros":
return F.conv2d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
weight,
self.get_bias(),
self.stride,
_pair(0),
self.dilation,
self.groups,
)
return F.conv2d(
input,
weight,
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.get_weight())
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for
@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module):
we allow, before we start to modify the derivatives to prevent
this.
"""
def __init__(self, channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0):
def __init__(
self,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0,
):
super(ActivationBalancer, self).__init__()
self.channel_dim = channel_dim
self.min_positive = min_positive
@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module):
self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor:
return ActivationBalancerFunction.apply(x, self.channel_dim,
self.min_positive, self.max_positive,
self.max_factor, self.min_abs,
self.max_abs)
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):
@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function):
= double_swish(x) * (1-s(x)) + s(x)
... so we just need to remember s(x) but not x itself.
"""
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
x = x.detach()
@ -349,18 +436,17 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
s, y = ctx.saved_tensors
return (y * (1-s) + s) * y_grad
return (y * (1 - s) + s) * y_grad
class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
that we approximate closely with x * sigmoid(x-1).
"""
return DoubleSwishFunction.apply(x)
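
A quick numerical check of the derivative identity quoted in DoubleSwishFunction's docstring, d/dx [x * sigmoid(x - 1)] = y * (1 - s) + s with s = sigmoid(x - 1) and y = x * s, which is what lets backward keep only s:

import torch

x = torch.randn(1000, dtype=torch.double, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s                                    # double_swish(x)
(grad_autograd,) = torch.autograd.grad(y.sum(), x)
grad_formula = y.detach() * (1 - s.detach()) + s.detach()
assert torch.allclose(grad_autograd, grad_formula)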
class ScaledEmbedding(nn.Module):
r"""This is a modified version of nn.Embedding that introduces a learnable scale
on the parameters. Note: due to how we initialize it, it's best used with
@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module):
[-0.1655, 0.9897, 0.0635]]])
"""
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
'scale_grad_by_freq', 'sparse']
__constants__ = [
"num_embeddings",
"embedding_dim",
"padding_idx",
"scale_grad_by_freq",
"sparse",
]
num_embeddings: int
embedding_dim: int
@ -453,33 +544,41 @@ class ScaledEmbedding(nn.Module):
weight: Tensor
sparse: bool
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False,
initial_speed: float = 1.0) -> None:
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False,
initial_speed: float = 1.0,
) -> None:
super(ScaledEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
assert (
padding_idx < self.num_embeddings
), "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
assert (
padding_idx >= -self.num_embeddings
), "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.scale_grad_by_freq = scale_grad_by_freq
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
self.sparse = sparse
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters(initial_speed)
def reset_parameters(self, initial_speed: float = 1.0) -> None:
std = 0.1 / initial_speed
nn.init.normal_(self.weight, std=std)
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
if self.padding_idx is not None:
with torch.no_grad():
@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module):
F = torch.nn.functional
scale = self.scale.exp()
if input.numel() < self.num_embeddings:
return F.embedding(
input, self.weight, self.padding_idx,
None, 2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq, self.sparse) * scale
return (
F.embedding(
input,
self.weight,
self.padding_idx,
None,
2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq,
self.sparse,
)
* scale
)
else:
return F.embedding(
input, self.weight * scale, self.padding_idx,
None, 2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq, self.sparse)
input,
self.weight * scale,
self.padding_idx,
None,
2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq,
self.sparse,
)
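
The numel() branch above reads like a cost trade-off: scaling the looked-up rows costs roughly input.numel() * embedding_dim multiplications, while scaling the whole table costs num_embeddings * embedding_dim, so the cheaper side is presumably taken. A rough comparison with assumed sizes (illustrative only):

num_embeddings, embedding_dim = 500, 512          # e.g. a small BPE vocabulary
scale_table = num_embeddings * embedding_dim      # cost of self.weight * scale
few_tokens, many_tokens = 2 * 30, 64 * 250        # two assumed batch shapes
print(few_tokens * embedding_dim, scale_table)    # 30720 vs 256000: scale the lookup result
print(many_tokens * embedding_dim, scale_table)   # 8192000 vs 256000: scale the table once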
def extra_repr(self) -> str:
s = '{num_embeddings}, {embedding_dim}, scale={scale}'
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
if self.padding_idx is not None:
s += ', padding_idx={padding_idx}'
s += ", padding_idx={padding_idx}"
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
s += ", scale_grad_by_freq={scale_grad_by_freq}"
if self.sparse is not False:
s += ', sparse=True'
s += ", sparse=True"
return s.format(**self.__dict__)
def _test_activation_balancer_sign():
channel_dim = 0
probs = torch.arange(0, 1, 0.01)
N = 1000
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
max_factor=0.2, min_abs=0.0)
m = ActivationBalancer(
channel_dim=0,
min_positive=0.05,
max_positive=0.95,
max_factor=0.2,
min_abs=0.0,
)
y_grad = torch.sign(torch.randn(probs.numel(), N))
@ -528,17 +644,23 @@ def _test_activation_balancer_sign():
print("_test_activation_balancer_sign: y grad = ", y_grad)
print("_test_activation_balancer_sign: x grad = ", x.grad)
def _test_activation_balancer_magnitude():
channel_dim = 0
magnitudes = torch.arange(0, 1, 0.01)
N = 1000
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
-1
)
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(channel_dim=0,
min_positive=0.0, max_positive=1.0,
max_factor=0.2,
min_abs=0.2, max_abs=0.8)
m = ActivationBalancer(
channel_dim=0,
min_positive=0.0,
max_positive=1.0,
max_factor=0.2,
min_abs=0.2,
max_abs=0.8,
)
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
@ -558,8 +680,8 @@ def _test_basic_norm():
y = m(x)
assert y.shape == x.shape
x_rms = (x**2).mean().sqrt()
y_rms = (y**2).mean().sqrt()
x_rms = (x ** 2).mean().sqrt()
y_rms = (y ** 2).mean().sqrt()
print("x rms = ", x_rms)
print("y rms = ", y_rms)
assert y_rms < x_rms
@ -573,7 +695,7 @@ def _test_double_swish_deriv():
torch.autograd.gradcheck(m, x)
if __name__ == '__main__':
if __name__ == "__main__":
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()


@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
# use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object
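
With LRSchedulerType = object, the type hints below accept both torch's built-in schedulers and the project's own scheduler class, relying only on the methods the checkpoint code actually calls (duck typing). A minimal sketch of the idea, with a hypothetical scheduler class:

from typing import Optional

LRSchedulerType = object           # same alias as above, repeated so the sketch is self-contained

class ToyScheduler:                # hypothetical, only to illustrate the duck typing
    def state_dict(self) -> dict:
        return {"step": 0}

def scheduler_state(scheduler: Optional[LRSchedulerType]) -> Optional[dict]:
    # Only state_dict() is assumed to exist, not a common base class.
    return scheduler.state_dict() if scheduler is not None else None

print(scheduler_state(ToyScheduler()))   # {'step': 0}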
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],