Mirror of https://github.com/k2-fsa/icefall.git

check some files

This commit is contained in:
parent 34aad74a2c
commit 05fd40ba68

.flake8: 3 changed lines
@@ -7,6 +7,9 @@ per-file-ignores =
    egs/librispeech/ASR/*/conformer.py: E501,
    egs/aishell/ASR/*/conformer.py: E501,
    egs/tedlium3/ASR/*/conformer.py: E501,
    egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
    egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,

    # invalid escape sequence (cause by tex formular), W605
    icefall/utils.py: E501, W605
@@ -16,13 +16,20 @@
# limitations under the License.

import copy
from encoder_interface import EncoderInterface
import math
import warnings
from typing import Optional, Tuple, Sequence
from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
from typing import Optional, Tuple

import torch
from encoder_interface import EncoderInterface
from scaling import (
    ActivationBalancer,
    BasicNorm,
    DoubleSwish,
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,
)
from torch import Tensor, nn

from icefall.utils import make_pad_mask
@@ -42,6 +49,7 @@ class Conformer(EncoderInterface):
        cnn_module_kernel (int): Kernel size of convolution module
        vgg_frontend (bool): whether to use vgg frontend.
    """

    def __init__(
        self,
        num_features: int,
@@ -80,9 +88,8 @@ class Conformer(EncoderInterface):
        )
        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -112,8 +119,9 @@ class Conformer(EncoderInterface):
        assert x.size(0) == lengths.max().item()
        mask = make_pad_mask(lengths)

        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
                         warmup=warmup)  # (T, N, C)
        x = self.encoder(
            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
        )  # (T, N, C)

        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

@@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module):

        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

        self.norm_final = BasicNorm(d_model)

        # try to ensure the output is close to zero-mean (or at least, zero-median).
        self.balancer = ActivationBalancer(channel_dim=-1,
                                           min_positive=0.45,
                                           max_positive=0.55,
                                           max_abs=6.0)
        self.balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: Tensor,
@@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module):
        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
        # completely bypass it.
        if self.training:
            alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1
            alpha = (
                warmup_scale
                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
                else 0.1
            )
        else:
            alpha = 1.0

        # macaron style feed forward module
        src = src + self.dropout(self.feed_forward_macaron(src))

        # multi-headed self-attention module
        src_att = self.self_attn(
            src,
@@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module):
        src = self.norm_final(self.balancer(src))

        if alpha != 1.0:
            src = alpha * src + (1-alpha) * src_orig
            src = alpha * src + (1 - alpha) * src_orig

        return src

@@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module):
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        warmup: float = 1.0
        warmup: float = 1.0,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

@@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module):
        """
        output = src

        num_layers = len(self.layers)

        for i, mod in enumerate(self.layers):
            output = mod(
                output,
@@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module):
        ), "embed_dim must be divisible by num_heads"

        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
        self.out_proj = ScaledLinear(
            embed_dim, embed_dim, bias=True, initial_scale=0.25
        )

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module):

        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
            q, k, v = nn.functional.linear(
                query, in_proj_weight, in_proj_bias
            ).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
@@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
                _b = _b[_start:_end]
            q = nn.functional.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
@@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
                _b = _b[_start:]
            v = nn.functional.linear(value, _w, _b)

        if attn_mask is not None:
            assert (
                attn_mask.dtype == torch.float32
@@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module):
        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
        # it will be in a better position to start learning something, i.e. to latch onto
        # the correct range.
        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
                                                  min_positive=0.05,
                                                  max_positive=1.0)
        self.deriv_balancer1 = ActivationBalancer(
            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
        )

        self.depthwise_conv = ScaledConv1d(
            channels,
@@ -878,9 +885,9 @@ class ConvolutionModule(nn.Module):
            bias=bias,
        )

        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
                                                  min_positive=0.05,
                                                  max_positive=1.0)
        self.deriv_balancer2 = ActivationBalancer(
            channel_dim=1, min_positive=0.05, max_positive=1.0
        )

        self.activation = DoubleSwish()

@@ -891,7 +898,7 @@ class ConvolutionModule(nn.Module):
            stride=1,
            padding=0,
            bias=bias,
            initial_scale=0.25
            initial_scale=0.25,
        )

    def forward(self, x: Tensor) -> Tensor:
@@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module):
        return x.permute(2, 0, 1)


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

@@ -936,11 +942,14 @@ class Conv2dSubsampling(nn.Module):
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 layer1_channels: int = 8,
                 layer2_channels: int = 32,
                 layer3_channels: int = 128) -> None:
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
    ) -> None:
        """
        Args:
          in_channels:
@@ -958,34 +967,41 @@ class Conv2dSubsampling(nn.Module):

        self.conv = nn.Sequential(
            ScaledConv2d(
                in_channels=1, out_channels=layer1_channels,
                kernel_size=3, padding=1,
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=1,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer1_channels, out_channels=layer2_channels,
                kernel_size=3, stride=2,
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer2_channels, out_channels=layer3_channels,
                kernel_size=3, stride=2,
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
        )
        self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
        self.out = ScaledLinear(
            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
        )
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(channel_dim=-1,
                                               min_positive=0.45,
                                               max_positive=0.55)
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.
@@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module):
        return x


if __name__ == '__main__':
if __name__ == "__main__":
    feature_dim = 50
    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
    batch_size = 5
    seq_len = 20
    # Just make sure the forward pass runs.
    f = c(torch.randn(batch_size, seq_len, feature_dim),
          torch.full((batch_size,), seq_len, dtype=torch.int64),
          warmup=0.5)
    f = c(
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
        warmup=0.5,
    )
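A minimal, self-contained sketch of the warmup-time stochastic "layer bypass" used in ConformerEncoderLayer.forward above: during training, a layer's output is occasionally down-weighted (alpha = 0.1) and blended with the layer input. The warmup_scale schedule and the layer_dropout default below are assumptions for illustration, not taken from this diff.

import torch
import torch.nn as nn


class BypassBlock(nn.Module):
    """Toy block illustrating the alpha-blend; the real layer wraps attention,
    convolution and feed-forward sub-modules instead of a single Linear."""

    def __init__(self, dim: int, layer_dropout: float = 0.075):
        super().__init__()
        self.layer_dropout = layer_dropout
        self.body = nn.Linear(dim, dim)  # stand-in for the real sub-modules

    def forward(self, x: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
        warmup_scale = min(0.1 + warmup, 1.0)  # assumed schedule
        if self.training:
            alpha = (
                warmup_scale
                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
                else 0.1
            )
        else:
            alpha = 1.0
        y = self.body(x)
        # alpha = 1.0 keeps the layer output; smaller alpha leans on the input.
        return alpha * y + (1 - alpha) * x if alpha != 1.0 else y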
@@ -16,15 +16,17 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScaledLinear


class Joiner(nn.Module):
    def __init__(self,
                 encoder_dim: int,
                 decoder_dim: int,
                 joiner_dim: int,
                 vocab_size: int):
    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
@@ -32,8 +34,10 @@ class Joiner(nn.Module):
        self.output_linear = ScaledLinear(joiner_dim, vocab_size)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
        project_input: bool = True
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
@@ -52,7 +56,9 @@ class Joiner(nn.Module):
        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
                decoder_out
            )
        else:
            logit = encoder_out + decoder_out

@@ -37,7 +37,7 @@ class Transducer(nn.Module):
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int
        vocab_size: int,
    ):
        """
        Args:
@@ -48,11 +48,11 @@ class Transducer(nn.Module):
            `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim). It should contain
            one attribute: `blank_id`.
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
            output shape is (N, T, U, vocab_size). Note that its output contains
            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
        """
        super().__init__()
@@ -63,8 +63,9 @@ class Transducer(nn.Module):
        self.decoder = decoder
        self.joiner = joiner

        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
                                           initial_speed=0.5)
        self.simple_am_proj = ScaledLinear(
            encoder_dim, vocab_size, initial_speed=0.5
        )
        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

    def forward(
@@ -166,15 +167,14 @@ class Transducer(nn.Module):
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned,
                             project_input=False)
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        pruned_loss = k2.rnnt_loss_pruned(
            logits=logits,
@@ -15,54 +15,86 @@
# limitations under the License.


import collections
from itertools import repeat
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Optional


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


_single = _ntuple(1)
_pair = _ntuple(2)


class ActivationBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor,
                channel_dim: int,
                min_positive: float,  # e.g. 0.05
                max_positive: float,  # e.g. 0.95
                max_factor: float,  # e.g. 0.01
                min_abs: float,  # e.g. 0.2
                max_abs: float,  # e.g. 100.0
    def forward(
        ctx,
        x: Tensor,
        channel_dim: int,
        min_positive: float,  # e.g. 0.05
        max_positive: float,  # e.g. 0.95
        max_factor: float,  # e.g. 0.01
        min_abs: float,  # e.g. 0.2
        max_abs: float,  # e.g. 100.0
    ) -> Tensor:
        if x.requires_grad:
            if channel_dim < 0:
                channel_dim += x.ndim
            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
            xgt0 = x > 0
            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
            factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
                       if min_positive != 0.0 else 0.0)
            factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
                       if max_positive != 1.0 else 0.0)
            proportion_positive = torch.mean(
                xgt0.to(x.dtype), dim=sum_dims, keepdim=True
            )
            factor1 = (
                (min_positive - proportion_positive).relu()
                * (max_factor / min_positive)
                if min_positive != 0.0
                else 0.0
            )
            factor2 = (
                (proportion_positive - max_positive).relu()
                * (max_factor / (max_positive - 1.0))
                if max_positive != 1.0
                else 0.0
            )
            factor = factor1 + factor2
            if isinstance(factor, float):
                factor = torch.zeros_like(proportion_positive)

            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
            below_threshold = (mean_abs < min_abs)
            above_threshold = (mean_abs > max_abs)
            below_threshold = mean_abs < min_abs
            above_threshold = mean_abs > max_abs

            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
            ctx.save_for_backward(
                factor, xgt0, below_threshold, above_threshold
            )
            ctx.max_factor = max_factor
            ctx.sum_dims = sum_dims
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
    def backward(
        ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None, None, None]:
        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
        dtype = x_grad.dtype
        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
        scale_factor = (
            (below_threshold.to(dtype) - above_threshold.to(dtype))
            * (xgt0.to(dtype) - 0.5)
            * (ctx.max_factor * 2.0)
        )

        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
        return x_grad - neg_delta_grad, None, None, None, None, None, None
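For context, a stripped-down sketch (not the icefall implementation) of the pattern ActivationBalancerFunction uses above: the forward pass is the identity, and the backward pass subtracts a small, sign-dependent term proportional to |grad|, so activation statistics are steered without changing the forward computation.

import torch
from torch import Tensor


class IdentityWithGradNudge(torch.autograd.Function):
    """Identity in forward; backward nudges gradients based on sign(x)."""

    @staticmethod
    def forward(ctx, x: Tensor, nudge: float) -> Tensor:
        ctx.save_for_backward(x > 0)
        ctx.nudge = nudge
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        (xgt0,) = ctx.saved_tensors
        # a term proportional to |grad| whose sign depends on sign(x),
        # analogous to the factor/scale_factor terms in the diff above
        delta = grad.abs() * ctx.nudge * (xgt0.to(grad.dtype) - 0.5)
        return grad - delta, None


# usage: y = IdentityWithGradNudge.apply(x, 0.01)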
@@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module):
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
    """

    def __init__(self,
                 num_channels: int,
                 channel_dim: int = -1,  # CAUTION: see documentation.
                 eps: float = 0.25,
                 learn_eps: bool = True) -> None:

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
            self.register_buffer('eps', torch.tensor(eps).log().detach())
            self.register_buffer("eps", torch.tensor(eps).log().detach())

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
                  self.eps.exp()) ** -0.5
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
            + self.eps.exp()
        ) ** -0.5
        return x * scales


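In equation form, the normalization computed by BasicNorm.forward above (with eps stored as a log value and exponentiated at run time) is:

y = x \cdot \Bigl(\operatorname{mean}_{\text{channel\_dim}}\bigl(x^2\bigr) + e^{\varepsilon}\Bigr)^{-1/2}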
class ScaledLinear(nn.Linear):
    """
    A modified version of nn.Linear where the parameters are scaled before
@@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear):
           Alternatively you can set it to more than 1 if you want it to
           initially train faster. Must be greater than 0.
    """

    def __init__(self, *args,
                 initial_scale: float = 1.0,
                 initial_speed: float = 1.0,
                 **kwargs):

    def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
        initial_speed: float = 1.0,
        **kwargs
    ):
        super(ScaledLinear, self).__init__(*args, **kwargs)
        initial_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)
            self.register_parameter("bias_scale", None)

        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear
        self._reset_parameters(
            initial_speed
        )  # Overrides the reset_parameters in nn.Linear

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
@@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear):
        return self.weight * self.weight_scale.exp()

    def get_bias(self):
        return (None if self.bias is None else
                self.bias * self.bias_scale.exp())
        return None if self.bias is None else self.bias * self.bias_scale.exp()

    def forward(self, input: Tensor) -> Tensor:
        return torch.nn.functional.linear(input, self.get_weight(),
                                          self.get_bias())
        return torch.nn.functional.linear(
            input, self.get_weight(), self.get_bias()
        )


class ScaledConv1d(nn.Conv1d):
    # See docs for ScaledLinear
    def __init__(self, *args,
                 initial_scale: float = 1.0,
                 initial_speed: float = 1.0,
                 **kwargs):
    def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
        initial_speed: float = 1.0,
        **kwargs
    ):
        super(ScaledConv1d, self).__init__(*args, **kwargs)
        initial_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)
        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
            self.register_parameter("bias_scale", None)
        self._reset_parameters(
            initial_speed
        )  # Overrides the reset_parameters in base class

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
@@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d):
        with torch.no_grad():
            self.weight_scale += torch.tensor(scale / std).log()

    def get_weight(self):
        return self.weight * self.weight_scale.exp()

    def get_bias(self):
        return (None if self.bias is None else
                self.bias * self.bias_scale.exp())
        return None if self.bias is None else self.bias * self.bias_scale.exp()

    def forward(self, input: Tensor) -> Tensor:
        F = torch.nn.functional
        if self.padding_mode != 'zeros':
            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            self.get_weight(), self.get_bias(), self.stride,
                            _single(0), self.dilation, self.groups)
        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
                        self.padding, self.dilation, self.groups)

        if self.padding_mode != "zeros":
            return F.conv1d(
                F.pad(
                    input,
                    self._reversed_padding_repeated_twice,
                    mode=self.padding_mode,
                ),
                self.get_weight(),
                self.get_bias(),
                self.stride,
                _single(0),
                self.dilation,
                self.groups,
            )
        return F.conv1d(
            input,
            self.get_weight(),
            self.get_bias(),
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class ScaledConv2d(nn.Conv2d):
    # See docs for ScaledLinear
    def __init__(self, *args,
                 initial_scale: float = 1.0,
                 initial_speed: float = 1.0,
                 **kwargs):
    def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
        initial_speed: float = 1.0,
        **kwargs
    ):
        super(ScaledConv2d, self).__init__(*args, **kwargs)
        initial_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)
        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
            self.register_parameter("bias_scale", None)
        self._reset_parameters(
            initial_speed
        )  # Overrides the reset_parameters in base class

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
@@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d):
        with torch.no_grad():
            self.weight_scale += torch.tensor(scale / std).log()

    def get_weight(self):
        return self.weight * self.weight_scale.exp()

    def get_bias(self):
        return (None if self.bias is None else
                self.bias * self.bias_scale.exp())
        return None if self.bias is None else self.bias * self.bias_scale.exp()

    def _conv_forward(self, input, weight):
        F = torch.nn.functional
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.get_bias(), self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.get_bias(), self.stride,
                        self.padding, self.dilation, self.groups)
        if self.padding_mode != "zeros":
            return F.conv2d(
                F.pad(
                    input,
                    self._reversed_padding_repeated_twice,
                    mode=self.padding_mode,
                ),
                weight,
                self.get_bias(),
                self.stride,
                _pair(0),
                self.dilation,
                self.groups,
            )
        return F.conv2d(
            input,
            weight,
            self.get_bias(),
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.get_weight())


class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
@@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module):
       we allow, before we start to modify the derivatives to prevent
       this.
    """

    def __init__(self, channel_dim: int,
                 min_positive: float = 0.05,
                 max_positive: float = 0.95,
                 max_factor: float = 0.01,
                 min_abs: float = 0.2,
                 max_abs: float = 100.0):

    def __init__(
        self,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.01,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
    ):
        super(ActivationBalancer, self).__init__()
        self.channel_dim = channel_dim
        self.min_positive = min_positive
@@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module):
        self.max_abs = max_abs

    def forward(self, x: Tensor) -> Tensor:
        return ActivationBalancerFunction.apply(x, self.channel_dim,
                                                self.min_positive, self.max_positive,
                                                self.max_factor, self.min_abs,
                                                self.max_abs)
        return ActivationBalancerFunction.apply(
            x,
            self.channel_dim,
            self.min_positive,
            self.max_positive,
            self.max_factor,
            self.min_abs,
            self.max_abs,
        )


class DoubleSwishFunction(torch.autograd.Function):
@@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     = double_swish(x) * (1-s(x)) + s(x)
    ... so we just need to remember s(x) but not x itself.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        x = x.detach()
@@ -349,18 +436,17 @@ class DoubleSwishFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        s, y = ctx.saved_tensors
        return (y * (1-s) + s) * y_grad
        return (y * (1 - s) + s) * y_grad


class DoubleSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        return DoubleSwishFunction.apply(x)


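For reference, the derivative that the DoubleSwishFunction backward pass above relies on, written out with the intermediate step (s = sigmoid(x - 1), y = double_swish(x) = x * s; this restates the docstring, it is not new behaviour):

\frac{d}{dx}\bigl[x\,\sigma(x-1)\bigr]
  = \sigma(x-1) + x\,\sigma(x-1)\bigl(1 - \sigma(x-1)\bigr)
  = s + y\,(1 - s)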
class ScaledEmbedding(nn.Module):
    r"""This is a modified version of nn.Embedding that introduces a learnable scale
    on the parameters. Note: due to how we initialize it, it's best used with
@@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module):
                 [-0.1655, 0.9897, 0.0635]]])

    """
    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
                     'scale_grad_by_freq', 'sparse']
    __constants__ = [
        "num_embeddings",
        "embedding_dim",
        "padding_idx",
        "scale_grad_by_freq",
        "sparse",
    ]

    num_embeddings: int
    embedding_dim: int
@@ -453,33 +544,41 @@ class ScaledEmbedding(nn.Module):
    weight: Tensor
    sparse: bool

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 initial_speed: float = 1.0) -> None:
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        initial_speed: float = 1.0,
    ) -> None:
        super(ScaledEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.scale_grad_by_freq = scale_grad_by_freq

        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
        self.sparse = sparse

        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.reset_parameters(initial_speed)

    def reset_parameters(self, initial_speed: float = 1.0) -> None:
        std = 0.1 / initial_speed
        nn.init.normal_(self.weight, std=std)
        nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
        nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())

        if self.padding_idx is not None:
            with torch.no_grad():
@@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module):
        F = torch.nn.functional
        scale = self.scale.exp()
        if input.numel() < self.num_embeddings:
            return F.embedding(
                input, self.weight, self.padding_idx,
                None, 2.0,  # None, 2.0 relate to normalization
                self.scale_grad_by_freq, self.sparse) * scale
            return (
                F.embedding(
                    input,
                    self.weight,
                    self.padding_idx,
                    None,
                    2.0,  # None, 2.0 relate to normalization
                    self.scale_grad_by_freq,
                    self.sparse,
                )
                * scale
            )
        else:
            return F.embedding(
                input, self.weight * scale, self.padding_idx,
                None, 2.0,  # None, 2.0 relates to normalization
                self.scale_grad_by_freq, self.sparse)
                input,
                self.weight * scale,
                self.padding_idx,
                None,
                2.0,  # None, 2.0 relates to normalization
                self.scale_grad_by_freq,
                self.sparse,
            )

    def extra_repr(self) -> str:
        s = '{num_embeddings}, {embedding_dim}, scale={scale}'
        s = "{num_embeddings}, {embedding_dim}, scale={scale}"
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
            s += ", padding_idx={padding_idx}"
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        if self.sparse is not False:
            s += ', sparse=True'
            s += ", sparse=True"
        return s.format(**self.__dict__)


def _test_activation_balancer_sign():
    channel_dim = 0
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
                           max_factor=0.2, min_abs=0.0)
    m = ActivationBalancer(
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

@@ -528,17 +644,23 @@ def _test_activation_balancer_sign():
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)


def _test_activation_balancer_magnitude():
    channel_dim = 0
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
        -1
    )
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(channel_dim=0,
                           min_positive=0.0, max_positive=1.0,
                           max_factor=0.2,
                           min_abs=0.2, max_abs=0.8)
    m = ActivationBalancer(
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

@@ -558,8 +680,8 @@ def _test_basic_norm():
    y = m(x)

    assert y.shape == x.shape
    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    x_rms = (x ** 2).mean().sqrt()
    y_rms = (y ** 2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
@@ -573,7 +695,7 @@ def _test_double_swish_deriv():
    torch.autograd.gradcheck(m, x)


if __name__ == '__main__':
if __name__ == "__main__":
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()
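A minimal sketch of the parameter-scaling idea that ScaledLinear, ScaledConv1d and ScaledConv2d above share (an illustration, not the icefall code): the scale is stored as a log value, so the effective multiplier stays positive, and it is applied to the weights at forward time.

import torch
import torch.nn as nn


class TinyScaledLinear(nn.Linear):
    """nn.Linear whose weight is multiplied by exp(weight_scale) in forward."""

    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        # store log(initial_scale); exp() keeps the effective scale positive
        self.weight_scale = nn.Parameter(
            torch.tensor(float(initial_scale)).log()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(
            x, self.weight * self.weight_scale.exp(), self.bias
        )


# usage: layer = TinyScaledLinear(80, 512, initial_scale=0.25)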
@@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer


# use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object


def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],