check some files

luomingshuang 2022-04-11 20:24:46 +08:00
parent 34aad74a2c
commit 05fd40ba68
6 changed files with 342 additions and 194 deletions

.flake8

@@ -7,6 +7,9 @@ per-file-ignores =
     egs/librispeech/ASR/*/conformer.py: E501,
     egs/aishell/ASR/*/conformer.py: E501,
     egs/tedlium3/ASR/*/conformer.py: E501,
+    egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
+    egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,
     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605

egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py

@@ -16,13 +16,20 @@
 # limitations under the License.

 import copy
-from encoder_interface import EncoderInterface
 import math
 import warnings
-from typing import Optional, Tuple, Sequence
-from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from typing import Optional, Tuple

 import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
 from torch import Tensor, nn

 from icefall.utils import make_pad_mask
@@ -42,6 +49,7 @@ class Conformer(EncoderInterface):
         cnn_module_kernel (int): Kernel size of convolution module
         vgg_frontend (bool): whether to use vgg frontend.
     """
+
     def __init__(
         self,
         num_features: int,
@@ -80,9 +88,8 @@ class Conformer(EncoderInterface):
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)

-
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,8 +119,9 @@ class Conformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()

         mask = make_pad_mask(lengths)

-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup=warmup)  # (T, N, C)
+        x = self.encoder(
+            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+        )  # (T, N, C)

         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module):
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

         self.norm_final = BasicNorm(d_model)

         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(channel_dim=-1,
-                                           min_positive=0.45,
-                                           max_positive=0.55,
-                                           max_abs=6.0)
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )

         self.dropout = nn.Dropout(dropout)

-
     def forward(
         self,
         src: Tensor,
@@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module):
         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
         # completely bypass it.
         if self.training:
-            alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else 0.1
+            )
         else:
             alpha = 1.0

         # macaron style feed forward module
         src = src + self.dropout(self.feed_forward_macaron(src))

         # multi-headed self-attention module
         src_att = self.self_attn(
             src,
@@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))

         if alpha != 1.0:
-            src = alpha * src + (1-alpha) * src_orig
+            src = alpha * src + (1 - alpha) * src_orig

         return src
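
The alpha blending restored above is the layer-warmup trick used throughout this file: the layer is always computed, but early in training its output is randomly mixed back toward its input, and with a small probability nearly bypassed. A minimal standalone sketch of that logic (not code from this commit; the layer_dropout value here is illustrative only):

import torch

def warmup_blend(src, src_orig, warmup_scale, layer_dropout=0.075, training=True):
    # warmup_scale rises toward 1.0 as training progresses; with probability
    # layer_dropout the layer is almost fully bypassed (alpha = 0.1).
    if training:
        alpha = (
            warmup_scale
            if torch.rand(()).item() <= (1.0 - layer_dropout)
            else 0.1
        )
    else:
        alpha = 1.0
    if alpha != 1.0:
        src = alpha * src + (1 - alpha) * src_orig
    return src

src_orig = torch.randn(10, 2, 8)        # (T, N, C): the layer's input
src = src_orig + torch.randn(10, 2, 8)  # stand-in for the layer's output
out = warmup_blend(src, src_orig, warmup_scale=0.5)
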
@@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module):
         )
         self.num_layers = num_layers

-
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup: float = 1.0
+        warmup: float = 1.0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module):
         """
         output = src

-        num_layers = len(self.layers)
-
         for i, mod in enumerate(self.layers):
             output = mod(
                 output,
@@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"

         self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )

         # linear transformation for positional encoding.
         self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module):

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)

         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)

-
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)

-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )

         self.depthwise_conv = ScaledConv1d(
             channels,
@@ -878,9 +885,9 @@
             bias=bias,
         )

-        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )

         self.activation = DoubleSwish()
@@ -891,7 +898,7 @@
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25
+            initial_scale=0.25,
         )

     def forward(self, x: Tensor) -> Tensor:
@@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)


-
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).
@@ -936,11 +942,14 @@
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
     """

-    def __init__(self, in_channels: int,
-                 out_channels: int,
-                 layer1_channels: int = 8,
-                 layer2_channels: int = 32,
-                 layer3_channels: int = 128) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+    ) -> None:
         """
         Args:
           in_channels:
@@ -958,34 +967,41 @@
         self.conv = nn.Sequential(
             ScaledConv2d(
-                in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, padding=1,
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
-                in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2,
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
-                in_channels=layer2_channels, out_channels=layer3_channels,
-                kernel_size=3, stride=2,
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
-        self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
         # constrain median of output to be close to zero.
-        self.out_balancer = ActivationBalancer(channel_dim=-1,
-                                               min_positive=0.45,
-                                               max_positive=0.55)
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module):
         return x


-if __name__ == '__main__':
+if __name__ == "__main__":
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
-    f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64),
-          warmup=0.5)
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )

egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py

@@ -16,15 +16,17 @@

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from scaling import ScaledLinear


 class Joiner(nn.Module):
-    def __init__(self,
-                 encoder_dim: int,
-                 decoder_dim: int,
-                 joiner_dim: int,
-                 vocab_size: int):
+    def __init__(
+        self,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
         super().__init__()

         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
@@ -32,8 +34,10 @@ class Joiner(nn.Module):
         self.output_linear = ScaledLinear(joiner_dim, vocab_size)

     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
-        project_input: bool = True
+        self,
+        encoder_out: torch.Tensor,
+        decoder_out: torch.Tensor,
+        project_input: bool = True,
     ) -> torch.Tensor:
         """
         Args:
@@ -52,7 +56,9 @@ class Joiner(nn.Module):
         assert encoder_out.shape[:-1] == decoder_out.shape[:-1]

         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
+                decoder_out
+            )
         else:
             logit = encoder_out + decoder_out
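
To make the project_input contract above concrete, here is a hedged sketch with plain nn.Linear standing in for ScaledLinear; the activation and final projection sit outside this hunk, so the tanh below is this sketch's assumption, not something the diff shows:

import torch
import torch.nn as nn

class TinyJoiner(nn.Module):
    # mirrors the structure above, minus the scaling machinery
    def __init__(self, encoder_dim, decoder_dim, joiner_dim, vocab_size):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out, project_input=True):
        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
                decoder_out
            )
        else:  # the caller has already projected both inputs to joiner_dim
            logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))

j = TinyJoiner(encoder_dim=4, decoder_dim=6, joiner_dim=8, vocab_size=10)
am = torch.randn(2, 5, 3, 4)  # (B, T, prune_range, encoder_dim)
lm = torch.randn(2, 5, 3, 6)  # (B, T, prune_range, decoder_dim)
print(j(am, lm).shape)        # torch.Size([2, 5, 3, 10])
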

egs/librispeech/ASR/pruned_transducer_stateless2/model.py

@@ -37,7 +37,7 @@ class Transducer(nn.Module):
         encoder_dim: int,
         decoder_dim: int,
         joiner_dim: int,
-        vocab_size: int
+        vocab_size: int,
     ):
         """
         Args:
@@ -48,11 +48,11 @@ class Transducer(nn.Module):
             `logit_lens` of shape (N,).
           decoder:
             It is the prediction network in the paper. Its input shape
-            is (N, U) and its output shape is (N, U, decoder_dim). It should contain
-            one attribute: `blank_id`.
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
           joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
-            output shape is (N, T, U, vocab_size). Note that its output contains
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
@@ -63,8 +63,9 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

-        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
-                                           initial_speed=0.5)
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_speed=0.5
+        )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

     def forward(
@@ -166,15 +167,14 @@ class Transducer(nn.Module):
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             am=self.joiner.encoder_proj(encoder_out),
             lm=self.joiner.decoder_proj(decoder_out),
-            ranges=ranges
+            ranges=ranges,
         )

         # logits : [B, T, prune_range, vocab_size]

         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned,
-                             project_input=False)
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

         pruned_loss = k2.rnnt_loss_pruned(
             logits=logits,
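
The "optimization for speed" comment above deserves a rough cost estimate (editorial, with made-up sizes): projecting the encoder and decoder outputs to joiner_dim before k2.do_rnnt_pruning means each projection runs once per frame rather than once per (frame, pruned-symbol) pair, which is why the joiner is then called with project_input=False.

# Hypothetical sizes, chosen only to show the ratio:
B, T, prune_range, encoder_dim, joiner_dim = 8, 500, 5, 512, 512

mults_project_first = B * T * encoder_dim * joiner_dim
mults_project_after = B * T * prune_range * encoder_dim * joiner_dim
print(mults_project_after // mults_project_first)  # 5: prune_range times more work
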

egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py

@@ -15,54 +15,86 @@
 # limitations under the License.

+import collections
+from itertools import repeat
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 from torch import Tensor
-from typing import Tuple, Optional


+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+_single = _ntuple(1)
+_pair = _ntuple(2)
+
+
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor,
-                channel_dim: int,
-                min_positive: float,  # e.g. 0.05
-                max_positive: float,  # e.g. 0.95
-                max_factor: float,  # e.g. 0.01
-                min_abs: float,  # e.g. 0.2
-                max_abs: float,  # e.g. 100.0
+    def forward(
+        ctx,
+        x: Tensor,
+        channel_dim: int,
+        min_positive: float,  # e.g. 0.05
+        max_positive: float,  # e.g. 0.95
+        max_factor: float,  # e.g. 0.01
+        min_abs: float,  # e.g. 0.2
+        max_abs: float,  # e.g. 100.0
     ) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
             sum_dims = [d for d in range(x.ndim) if d != channel_dim]
             xgt0 = x > 0
-            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
-            factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
-                       if min_positive != 0.0 else 0.0)
-            factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
-                       if max_positive != 1.0 else 0.0)
+            proportion_positive = torch.mean(
+                xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            )
+            factor1 = (
+                (min_positive - proportion_positive).relu()
+                * (max_factor / min_positive)
+                if min_positive != 0.0
+                else 0.0
+            )
+            factor2 = (
+                (proportion_positive - max_positive).relu()
+                * (max_factor / (max_positive - 1.0))
+                if max_positive != 1.0
+                else 0.0
+            )
             factor = factor1 + factor2
             if isinstance(factor, float):
                 factor = torch.zeros_like(proportion_positive)

             mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
-            below_threshold = (mean_abs < min_abs)
-            above_threshold = (mean_abs > max_abs)
+            below_threshold = mean_abs < min_abs
+            above_threshold = mean_abs > max_abs

-            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
+            ctx.save_for_backward(
+                factor, xgt0, below_threshold, above_threshold
+            )
             ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x

     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
+    def backward(
+        ctx, x_grad: Tensor
+    ) -> Tuple[Tensor, None, None, None, None, None, None]:
         factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
-        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
-                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
+        scale_factor = (
+            (below_threshold.to(dtype) - above_threshold.to(dtype))
+            * (xgt0.to(dtype) - 0.5)
+            * (ctx.max_factor * 2.0)
+        )
         neg_delta_grad = x_grad.abs() * (factor + scale_factor)
         return x_grad - neg_delta_grad, None, None, None, None, None, None
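
ActivationBalancerFunction follows a pattern worth spelling out: forward() is exactly the identity, and all of the balancing happens by nudging gradients in backward(). The same pattern in miniature (an editorial sketch, not code from the commit; the 0.01 constant plays the role of max_factor):

import torch

class IdentityWithGradNudge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x > 0)
        return x  # activations pass through untouched

    @staticmethod
    def backward(ctx, g):
        (xgt0,) = ctx.saved_tensors
        # lean on the gradient according to sign statistics, the same
        # role `factor + scale_factor` plays above
        return g - 0.01 * g.abs() * (xgt0.to(g.dtype) - 0.5)

x = torch.randn(4, requires_grad=True)
y = IdentityWithGradNudge.apply(x)
assert torch.equal(y, x)  # identity in the forward direction
y.sum().backward()        # x.grad deviates slightly from all-ones
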
@@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module):
        learn_eps: if true, we learn epsilon; if false, we keep it
           at the initial value.
     """
-    def __init__(self,
-                 num_channels: int,
-                 channel_dim: int = -1,  # CAUTION: see documentation.
-                 eps: float = 0.25,
-                 learn_eps: bool = True) -> None:
+
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+        eps: float = 0.25,
+        learn_eps: bool = True,
+    ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         if learn_eps:
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
-            self.register_buffer('eps', torch.tensor(eps).log().detach())
+            self.register_buffer("eps", torch.tensor(eps).log().detach())

     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
-                  self.eps.exp()) ** -0.5
+        scales = (
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+            + self.eps.exp()
+        ) ** -0.5
         return x * scales


 class ScaledLinear(nn.Linear):
     """
     A modified version of nn.Linear where the parameters are scaled before
@@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear):
           Alternatively you can set it to more than 1 if you want it to
           initially train faster. Must be greater than 0.
     """
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
         super(ScaledLinear, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
-            self.register_parameter('bias_scale', None)
+            self.register_parameter("bias_scale", None)

-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in nn.Linear

     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
@@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear):
         return self.weight * self.weight_scale.exp()

     def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()

     def forward(self, input: Tensor) -> Tensor:
-        return torch.nn.functional.linear(input, self.get_weight(),
-                                          self.get_bias())
+        return torch.nn.functional.linear(
+            input, self.get_weight(), self.get_bias()
+        )


 class ScaledConv1d(nn.Conv1d):
     # See docs for ScaledLinear
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
-            self.register_parameter('bias_scale', None)
+            self.register_parameter("bias_scale", None)

-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in base class

     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
@@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d):
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()

     def get_weight(self):
         return self.weight * self.weight_scale.exp()

     def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()

     def forward(self, input: Tensor) -> Tensor:
         F = torch.nn.functional
-        if self.padding_mode != 'zeros':
-            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            self.get_weight(), self.get_bias(), self.stride,
-                            _single(0), self.dilation, self.groups)
-        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
-                        self.padding, self.dilation, self.groups)
+        if self.padding_mode != "zeros":
+            return F.conv1d(
+                F.pad(
+                    input,
+                    self._reversed_padding_repeated_twice,
+                    mode=self.padding_mode,
+                ),
+                self.get_weight(),
+                self.get_bias(),
+                self.stride,
+                _single(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv1d(
+            input,
+            self.get_weight(),
+            self.get_bias(),
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )


 class ScaledConv2d(nn.Conv2d):
     # See docs for ScaledLinear
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
-            self.register_parameter('bias_scale', None)
+            self.register_parameter("bias_scale", None)

-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in base class

     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
@@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d):
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()

     def get_weight(self):
         return self.weight * self.weight_scale.exp()

     def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()

     def _conv_forward(self, input, weight):
         F = torch.nn.functional
-        if self.padding_mode != 'zeros':
-            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.get_bias(), self.stride,
-                            _pair(0), self.dilation, self.groups)
-        return F.conv2d(input, weight, self.get_bias(), self.stride,
-                        self.padding, self.dilation, self.groups)
+        if self.padding_mode != "zeros":
+            return F.conv2d(
+                F.pad(
+                    input,
+                    self._reversed_padding_repeated_twice,
+                    mode=self.padding_mode,
+                ),
+                weight,
+                self.get_bias(),
+                self.stride,
+                _pair(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv2d(
+            input,
+            weight,
+            self.get_bias(),
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )

     def forward(self, input: Tensor) -> Tensor:
         return self._conv_forward(input, self.get_weight())


 class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
@@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module):
            we allow, before we start to modify the derivatives to prevent
            this.
     """
-    def __init__(self, channel_dim: int,
-                 min_positive: float = 0.05,
-                 max_positive: float = 0.95,
-                 max_factor: float = 0.01,
-                 min_abs: float = 0.2,
-                 max_abs: float = 100.0):
+
+    def __init__(
+        self,
+        channel_dim: int,
+        min_positive: float = 0.05,
+        max_positive: float = 0.95,
+        max_factor: float = 0.01,
+        min_abs: float = 0.2,
+        max_abs: float = 100.0,
+    ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.min_positive = min_positive
@@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module):
         self.max_abs = max_abs

     def forward(self, x: Tensor) -> Tensor:
-        return ActivationBalancerFunction.apply(x, self.channel_dim,
-                                                self.min_positive, self.max_positive,
-                                                self.max_factor, self.min_abs,
-                                                self.max_abs)
+        return ActivationBalancerFunction.apply(
+            x,
+            self.channel_dim,
+            self.min_positive,
+            self.max_positive,
+            self.max_factor,
+            self.min_abs,
+            self.max_abs,
+        )


 class DoubleSwishFunction(torch.autograd.Function):
@@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function):
        = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.
     """
+
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         x = x.detach()
@@ -349,18 +436,17 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
         s, y = ctx.saved_tensors
-        return (y * (1-s) + s) * y_grad
+        return (y * (1 - s) + s) * y_grad


-
 class DoubleSwish(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
         return DoubleSwishFunction.apply(x)


 class ScaledEmbedding(nn.Module):
     r"""This is a modified version of nn.Embedding that introduces a learnable scale
     on the parameters. Note: due to how we initialize it, it's best used with
@@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module):
                  [-0.1655,  0.9897,  0.0635]]])
     """

-    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
-                     'scale_grad_by_freq', 'sparse']
+    __constants__ = [
+        "num_embeddings",
+        "embedding_dim",
+        "padding_idx",
+        "scale_grad_by_freq",
+        "sparse",
+    ]

     num_embeddings: int
     embedding_dim: int
@@ -453,33 +544,41 @@
     weight: Tensor
     sparse: bool

-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 initial_speed: float = 1.0) -> None:
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        initial_speed: float = 1.0,
+    ) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         if padding_idx is not None:
             if padding_idx > 0:
-                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert (
+                    padding_idx < self.num_embeddings
+                ), "Padding_idx must be within num_embeddings"
             elif padding_idx < 0:
-                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert (
+                    padding_idx >= -self.num_embeddings
+                ), "Padding_idx must be within num_embeddings"
                 padding_idx = self.num_embeddings + padding_idx
         self.padding_idx = padding_idx
         self.scale_grad_by_freq = scale_grad_by_freq

         self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
         self.sparse = sparse

         self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
         self.reset_parameters(initial_speed)

     def reset_parameters(self, initial_speed: float = 1.0) -> None:
         std = 0.1 / initial_speed
         nn.init.normal_(self.weight, std=std)
-        nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
+        nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())

         if self.padding_idx is not None:
             with torch.no_grad():
@@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module):
         F = torch.nn.functional
         scale = self.scale.exp()
         if input.numel() < self.num_embeddings:
-            return F.embedding(
-                input, self.weight, self.padding_idx,
-                None, 2.0,  # None, 2.0 relate to normalization
-                self.scale_grad_by_freq, self.sparse) * scale
+            return (
+                F.embedding(
+                    input,
+                    self.weight,
+                    self.padding_idx,
+                    None,
+                    2.0,  # None, 2.0 relate to normalization
+                    self.scale_grad_by_freq,
+                    self.sparse,
+                )
+                * scale
+            )
         else:
-            return F.embedding(
-                input, self.weight * scale, self.padding_idx,
-                None, 2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq, self.sparse)
+            return F.embedding(
+                input,
+                self.weight * scale,
+                self.padding_idx,
+                None,
+                2.0,  # None, 2.0 relates to normalization
+                self.scale_grad_by_freq,
+                self.sparse,
+            )

     def extra_repr(self) -> str:
-        s = '{num_embeddings}, {embedding_dim}, scale={scale}'
+        s = "{num_embeddings}, {embedding_dim}, scale={scale}"
         if self.padding_idx is not None:
-            s += ', padding_idx={padding_idx}'
+            s += ", padding_idx={padding_idx}"
         if self.scale_grad_by_freq is not False:
-            s += ', scale_grad_by_freq={scale_grad_by_freq}'
+            s += ", scale_grad_by_freq={scale_grad_by_freq}"
         if self.sparse is not False:
-            s += ', sparse=True'
+            s += ", sparse=True"
         return s.format(**self.__dict__)


 def _test_activation_balancer_sign():
-    channel_dim = 0
     probs = torch.arange(0, 1, 0.01)
     N = 1000
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
-                           max_factor=0.2, min_abs=0.0)
+    m = ActivationBalancer(
+        channel_dim=0,
+        min_positive=0.05,
+        max_positive=0.95,
+        max_factor=0.2,
+        min_abs=0.0,
+    )

     y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -528,17 +644,23 @@ def _test_activation_balancer_sign():
     print("_test_activation_balancer_sign: y grad = ", y_grad)
     print("_test_activation_balancer_sign: x grad = ", x.grad)


 def _test_activation_balancer_magnitude():
-    channel_dim = 0
     magnitudes = torch.arange(0, 1, 0.01)
     N = 1000
-    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
+        -1
+    )
     x = x.detach()
     x.requires_grad = True
-    m = ActivationBalancer(channel_dim=0,
-                           min_positive=0.0, max_positive=1.0,
-                           max_factor=0.2,
-                           min_abs=0.2, max_abs=0.8)
+    m = ActivationBalancer(
+        channel_dim=0,
+        min_positive=0.0,
+        max_positive=1.0,
+        max_factor=0.2,
+        min_abs=0.2,
+        max_abs=0.8,
+    )

     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
@@ -558,8 +680,8 @@ def _test_basic_norm():
     y = m(x)

     assert y.shape == x.shape
-    x_rms = (x**2).mean().sqrt()
-    y_rms = (y**2).mean().sqrt()
+    x_rms = (x ** 2).mean().sqrt()
+    y_rms = (y ** 2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
     assert y_rms < x_rms
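
The RMS claim this test checks follows directly from BasicNorm's forward pass reformatted earlier: x is divided by sqrt(mean(x**2) + eps) per channel, so the output RMS sits strictly below 1 whenever eps > 0. A standalone restatement (an editorial sketch using the default eps=0.25 from above):

import torch

x = torch.randn(3, 8) * 4.0  # deliberately large RMS
eps = 0.25                   # the default eps above
scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
y = x * scales
print((x ** 2).mean().sqrt())  # roughly 4
print((y ** 2).mean().sqrt())  # a bit below 1, hence y_rms < x_rms
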
@@ -573,7 +695,7 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


-if __name__ == '__main__':
+if __name__ == "__main__":
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
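
The derivative identity quoted in DoubleSwishFunction's docstring, with s(x) = sigmoid(x - 1) and double_swish(x) = x * s(x), says d/dx double_swish(x) = double_swish(x) * (1 - s(x)) + s(x), which is why backward() only needs to remember s and y. A quick numeric check (editorial, not part of the commit):

import torch

x = torch.linspace(-3.0, 3.0, 7, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s
(autograd_grad,) = torch.autograd.grad(y.sum(), x)
manual_grad = y * (1 - s) + s  # the identity used in backward()
assert torch.allclose(autograd_grad, manual_grad, atol=1e-6)
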

icefall/checkpoint.py

@@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer

 # use duck typing for LRScheduler since we have different possibilities, see
 # our class LRScheduler.
 LRSchedulerType = object


 def save_checkpoint(
     filename: Path,
     model: Union[nn.Module, DDP],