mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
check some files
This commit is contained in:
parent
34aad74a2c
commit
05fd40ba68
3
.flake8
3
.flake8
@ -7,6 +7,9 @@ per-file-ignores =
|
|||||||
egs/librispeech/ASR/*/conformer.py: E501,
|
egs/librispeech/ASR/*/conformer.py: E501,
|
||||||
egs/aishell/ASR/*/conformer.py: E501,
|
egs/aishell/ASR/*/conformer.py: E501,
|
||||||
egs/tedlium3/ASR/*/conformer.py: E501,
|
egs/tedlium3/ASR/*/conformer.py: E501,
|
||||||
|
egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
|
||||||
|
egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,
|
||||||
|
|
||||||
# invalid escape sequence (cause by tex formular), W605
|
# invalid escape sequence (cause by tex formular), W605
|
||||||
icefall/utils.py: E501, W605
|
icefall/utils.py: E501, W605
|
||||||
|
|
||||||
|
@ -16,13 +16,20 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from encoder_interface import EncoderInterface
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple, Sequence
|
from typing import Optional, Tuple
|
||||||
from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import (
|
||||||
|
ActivationBalancer,
|
||||||
|
BasicNorm,
|
||||||
|
DoubleSwish,
|
||||||
|
ScaledConv1d,
|
||||||
|
ScaledConv2d,
|
||||||
|
ScaledLinear,
|
||||||
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
@ -42,6 +49,7 @@ class Conformer(EncoderInterface):
|
|||||||
cnn_module_kernel (int): Kernel size of convolution module
|
cnn_module_kernel (int): Kernel size of convolution module
|
||||||
vgg_frontend (bool): whether to use vgg frontend.
|
vgg_frontend (bool): whether to use vgg frontend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
@ -80,9 +88,8 @@ class Conformer(EncoderInterface):
|
|||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -112,8 +119,9 @@ class Conformer(EncoderInterface):
|
|||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
|
x = self.encoder(
|
||||||
warmup=warmup) # (T, N, C)
|
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||||
|
|
||||||
|
|
||||||
self.norm_final = BasicNorm(d_model)
|
self.norm_final = BasicNorm(d_model)
|
||||||
|
|
||||||
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||||
self.balancer = ActivationBalancer(channel_dim=-1,
|
self.balancer = ActivationBalancer(
|
||||||
min_positive=0.45,
|
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||||
max_positive=0.55,
|
)
|
||||||
max_abs=6.0)
|
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||||
# completely bypass it.
|
# completely bypass it.
|
||||||
if self.training:
|
if self.training:
|
||||||
alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1
|
alpha = (
|
||||||
|
warmup_scale
|
||||||
|
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
|
||||||
|
else 0.1
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
alpha = 1.0
|
alpha = 1.0
|
||||||
|
|
||||||
# macaron style feed forward module
|
# macaron style feed forward module
|
||||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||||
|
|
||||||
|
|
||||||
# multi-headed self-attention module
|
# multi-headed self-attention module
|
||||||
src_att = self.self_attn(
|
src_att = self.self_attn(
|
||||||
src,
|
src,
|
||||||
@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src = self.norm_final(self.balancer(src))
|
src = self.norm_final(self.balancer(src))
|
||||||
|
|
||||||
if alpha != 1.0:
|
if alpha != 1.0:
|
||||||
src = alpha * src + (1-alpha) * src_orig
|
src = alpha * src + (1 - alpha) * src_orig
|
||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
warmup: float = 1.0
|
warmup: float = 1.0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
r"""Pass the input through the encoder layers in turn.
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
num_layers = len(self.layers)
|
|
||||||
|
|
||||||
for i, mod in enumerate(self.layers):
|
for i, mod in enumerate(self.layers):
|
||||||
output = mod(
|
output = mod(
|
||||||
output,
|
output,
|
||||||
@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
), "embed_dim must be divisible by num_heads"
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
|
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
|
||||||
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
|
self.out_proj = ScaledLinear(
|
||||||
|
embed_dim, embed_dim, bias=True, initial_scale=0.25
|
||||||
|
)
|
||||||
|
|
||||||
# linear transformation for positional encoding.
|
# linear transformation for positional encoding.
|
||||||
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
|
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
|
||||||
@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
if torch.equal(query, key) and torch.equal(key, value):
|
if torch.equal(query, key) and torch.equal(key, value):
|
||||||
# self-attention
|
# self-attention
|
||||||
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
|
q, k, v = nn.functional.linear(
|
||||||
|
query, in_proj_weight, in_proj_bias
|
||||||
|
).chunk(3, dim=-1)
|
||||||
|
|
||||||
elif torch.equal(key, value):
|
elif torch.equal(key, value):
|
||||||
# encoder-decoder attention
|
# encoder-decoder attention
|
||||||
@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
_b = _b[_start:_end]
|
_b = _b[_start:_end]
|
||||||
q = nn.functional.linear(query, _w, _b)
|
q = nn.functional.linear(query, _w, _b)
|
||||||
|
|
||||||
|
|
||||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
_b = in_proj_bias
|
_b = in_proj_bias
|
||||||
_start = embed_dim
|
_start = embed_dim
|
||||||
@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
_b = _b[_start:]
|
_b = _b[_start:]
|
||||||
v = nn.functional.linear(value, _w, _b)
|
v = nn.functional.linear(value, _w, _b)
|
||||||
|
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
assert (
|
assert (
|
||||||
attn_mask.dtype == torch.float32
|
attn_mask.dtype == torch.float32
|
||||||
@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module):
|
|||||||
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
||||||
# it will be in a better position to start learning something, i.e. to latch onto
|
# it will be in a better position to start learning something, i.e. to latch onto
|
||||||
# the correct range.
|
# the correct range.
|
||||||
self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
|
self.deriv_balancer1 = ActivationBalancer(
|
||||||
min_positive=0.05,
|
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||||
max_positive=1.0)
|
)
|
||||||
|
|
||||||
self.depthwise_conv = ScaledConv1d(
|
self.depthwise_conv = ScaledConv1d(
|
||||||
channels,
|
channels,
|
||||||
@ -878,9 +885,9 @@ class ConvolutionModule(nn.Module):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
|
self.deriv_balancer2 = ActivationBalancer(
|
||||||
min_positive=0.05,
|
channel_dim=1, min_positive=0.05, max_positive=1.0
|
||||||
max_positive=1.0)
|
)
|
||||||
|
|
||||||
self.activation = DoubleSwish()
|
self.activation = DoubleSwish()
|
||||||
|
|
||||||
@ -891,7 +898,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=0,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
initial_scale=0.25
|
initial_scale=0.25,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module):
|
|||||||
return x.permute(2, 0, 1)
|
return x.permute(2, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSubsampling(nn.Module):
|
class Conv2dSubsampling(nn.Module):
|
||||||
"""Convolutional 2D subsampling (to 1/4 length).
|
"""Convolutional 2D subsampling (to 1/4 length).
|
||||||
|
|
||||||
@ -936,11 +942,14 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int,
|
def __init__(
|
||||||
out_channels: int,
|
self,
|
||||||
layer1_channels: int = 8,
|
in_channels: int,
|
||||||
layer2_channels: int = 32,
|
out_channels: int,
|
||||||
layer3_channels: int = 128) -> None:
|
layer1_channels: int = 8,
|
||||||
|
layer2_channels: int = 32,
|
||||||
|
layer3_channels: int = 128,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
in_channels:
|
in_channels:
|
||||||
@ -958,34 +967,41 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
|
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
ScaledConv2d(
|
ScaledConv2d(
|
||||||
in_channels=1, out_channels=layer1_channels,
|
in_channels=1,
|
||||||
kernel_size=3, padding=1,
|
out_channels=layer1_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
),
|
),
|
||||||
ActivationBalancer(channel_dim=1),
|
ActivationBalancer(channel_dim=1),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
ScaledConv2d(
|
ScaledConv2d(
|
||||||
in_channels=layer1_channels, out_channels=layer2_channels,
|
in_channels=layer1_channels,
|
||||||
kernel_size=3, stride=2,
|
out_channels=layer2_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
),
|
),
|
||||||
ActivationBalancer(channel_dim=1),
|
ActivationBalancer(channel_dim=1),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
ScaledConv2d(
|
ScaledConv2d(
|
||||||
in_channels=layer2_channels, out_channels=layer3_channels,
|
in_channels=layer2_channels,
|
||||||
kernel_size=3, stride=2,
|
out_channels=layer3_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
),
|
),
|
||||||
ActivationBalancer(channel_dim=1),
|
ActivationBalancer(channel_dim=1),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
)
|
)
|
||||||
self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
|
self.out = ScaledLinear(
|
||||||
|
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
||||||
|
)
|
||||||
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
||||||
# itself has learned scale, so the extra degree of freedom is not
|
# itself has learned scale, so the extra degree of freedom is not
|
||||||
# needed.
|
# needed.
|
||||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||||
# constrain median of output to be close to zero.
|
# constrain median of output to be close to zero.
|
||||||
self.out_balancer = ActivationBalancer(channel_dim=-1,
|
self.out_balancer = ActivationBalancer(
|
||||||
min_positive=0.45,
|
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||||
max_positive=0.55)
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Subsample x.
|
"""Subsample x.
|
||||||
@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
if __name__ == '__main__':
|
|
||||||
feature_dim = 50
|
feature_dim = 50
|
||||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
seq_len = 20
|
seq_len = 20
|
||||||
# Just make sure the forward pass runs.
|
# Just make sure the forward pass runs.
|
||||||
f = c(torch.randn(batch_size, seq_len, feature_dim),
|
f = c(
|
||||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
torch.randn(batch_size, seq_len, feature_dim),
|
||||||
warmup=0.5)
|
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||||
|
warmup=0.5,
|
||||||
|
)
|
||||||
|
@ -16,15 +16,17 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
class Joiner(nn.Module):
|
class Joiner(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
encoder_dim: int,
|
self,
|
||||||
decoder_dim: int,
|
encoder_dim: int,
|
||||||
joiner_dim: int,
|
decoder_dim: int,
|
||||||
vocab_size: int):
|
joiner_dim: int,
|
||||||
|
vocab_size: int,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
|
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
|
||||||
@ -32,8 +34,10 @@ class Joiner(nn.Module):
|
|||||||
self.output_linear = ScaledLinear(joiner_dim, vocab_size)
|
self.output_linear = ScaledLinear(joiner_dim, vocab_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
|
self,
|
||||||
project_input: bool = True
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
project_input: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -52,7 +56,9 @@ class Joiner(nn.Module):
|
|||||||
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
||||||
|
|
||||||
if project_input:
|
if project_input:
|
||||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
||||||
|
decoder_out
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ class Transducer(nn.Module):
|
|||||||
encoder_dim: int,
|
encoder_dim: int,
|
||||||
decoder_dim: int,
|
decoder_dim: int,
|
||||||
joiner_dim: int,
|
joiner_dim: int,
|
||||||
vocab_size: int
|
vocab_size: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -48,11 +48,11 @@ class Transducer(nn.Module):
|
|||||||
`logit_lens` of shape (N,).
|
`logit_lens` of shape (N,).
|
||||||
decoder:
|
decoder:
|
||||||
It is the prediction network in the paper. Its input shape
|
It is the prediction network in the paper. Its input shape
|
||||||
is (N, U) and its output shape is (N, U, decoder_dim). It should contain
|
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||||
one attribute: `blank_id`.
|
It should contain one attribute: `blank_id`.
|
||||||
joiner:
|
joiner:
|
||||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
|
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||||
output shape is (N, T, U, vocab_size). Note that its output contains
|
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||||
unnormalized probs, i.e., not processed by log-softmax.
|
unnormalized probs, i.e., not processed by log-softmax.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -63,8 +63,9 @@ class Transducer(nn.Module):
|
|||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.joiner = joiner
|
self.joiner = joiner
|
||||||
|
|
||||||
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
|
self.simple_am_proj = ScaledLinear(
|
||||||
initial_speed=0.5)
|
encoder_dim, vocab_size, initial_speed=0.5
|
||||||
|
)
|
||||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -166,15 +167,14 @@ class Transducer(nn.Module):
|
|||||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||||
am=self.joiner.encoder_proj(encoder_out),
|
am=self.joiner.encoder_proj(encoder_out),
|
||||||
lm=self.joiner.decoder_proj(decoder_out),
|
lm=self.joiner.decoder_proj(decoder_out),
|
||||||
ranges=ranges
|
ranges=ranges,
|
||||||
)
|
)
|
||||||
|
|
||||||
# logits : [B, T, prune_range, vocab_size]
|
# logits : [B, T, prune_range, vocab_size]
|
||||||
|
|
||||||
# project_input=False since we applied the decoder's input projections
|
# project_input=False since we applied the decoder's input projections
|
||||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||||
logits = self.joiner(am_pruned, lm_pruned,
|
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||||
project_input=False)
|
|
||||||
|
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
|
@ -15,54 +15,86 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import collections
|
||||||
|
from itertools import repeat
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Tuple, Optional
|
|
||||||
|
|
||||||
|
|
||||||
|
def _ntuple(n):
|
||||||
|
def parse(x):
|
||||||
|
if isinstance(x, collections.Iterable):
|
||||||
|
return x
|
||||||
|
return tuple(repeat(x, n))
|
||||||
|
|
||||||
|
return parse
|
||||||
|
|
||||||
|
|
||||||
|
_single = _ntuple(1)
|
||||||
|
_pair = _ntuple(2)
|
||||||
|
|
||||||
|
|
||||||
class ActivationBalancerFunction(torch.autograd.Function):
|
class ActivationBalancerFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor,
|
def forward(
|
||||||
channel_dim: int,
|
ctx,
|
||||||
min_positive: float, # e.g. 0.05
|
x: Tensor,
|
||||||
max_positive: float, # e.g. 0.95
|
channel_dim: int,
|
||||||
max_factor: float, # e.g. 0.01
|
min_positive: float, # e.g. 0.05
|
||||||
min_abs: float, # e.g. 0.2
|
max_positive: float, # e.g. 0.95
|
||||||
max_abs: float, # e.g. 100.0
|
max_factor: float, # e.g. 0.01
|
||||||
|
min_abs: float, # e.g. 0.2
|
||||||
|
max_abs: float, # e.g. 100.0
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if x.requires_grad:
|
if x.requires_grad:
|
||||||
if channel_dim < 0:
|
if channel_dim < 0:
|
||||||
channel_dim += x.ndim
|
channel_dim += x.ndim
|
||||||
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
||||||
xgt0 = x > 0
|
xgt0 = x > 0
|
||||||
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
|
proportion_positive = torch.mean(
|
||||||
factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
|
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
|
||||||
if min_positive != 0.0 else 0.0)
|
)
|
||||||
factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
|
factor1 = (
|
||||||
if max_positive != 1.0 else 0.0)
|
(min_positive - proportion_positive).relu()
|
||||||
|
* (max_factor / min_positive)
|
||||||
|
if min_positive != 0.0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
factor2 = (
|
||||||
|
(proportion_positive - max_positive).relu()
|
||||||
|
* (max_factor / (max_positive - 1.0))
|
||||||
|
if max_positive != 1.0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
factor = factor1 + factor2
|
factor = factor1 + factor2
|
||||||
if isinstance(factor, float):
|
if isinstance(factor, float):
|
||||||
factor = torch.zeros_like(proportion_positive)
|
factor = torch.zeros_like(proportion_positive)
|
||||||
|
|
||||||
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
|
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
|
||||||
below_threshold = (mean_abs < min_abs)
|
below_threshold = mean_abs < min_abs
|
||||||
above_threshold = (mean_abs > max_abs)
|
above_threshold = mean_abs > max_abs
|
||||||
|
|
||||||
ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
|
ctx.save_for_backward(
|
||||||
|
factor, xgt0, below_threshold, above_threshold
|
||||||
|
)
|
||||||
ctx.max_factor = max_factor
|
ctx.max_factor = max_factor
|
||||||
ctx.sum_dims = sum_dims
|
ctx.sum_dims = sum_dims
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
|
def backward(
|
||||||
|
ctx, x_grad: Tensor
|
||||||
|
) -> Tuple[Tensor, None, None, None, None, None, None]:
|
||||||
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
|
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
|
||||||
dtype = x_grad.dtype
|
dtype = x_grad.dtype
|
||||||
scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
|
scale_factor = (
|
||||||
(xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
|
(below_threshold.to(dtype) - above_threshold.to(dtype))
|
||||||
|
* (xgt0.to(dtype) - 0.5)
|
||||||
|
* (ctx.max_factor * 2.0)
|
||||||
|
)
|
||||||
|
|
||||||
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
|
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
|
||||||
return x_grad - neg_delta_grad, None, None, None, None, None, None
|
return x_grad - neg_delta_grad, None, None, None, None, None, None
|
||||||
@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module):
|
|||||||
learn_eps: if true, we learn epsilon; if false, we keep it
|
learn_eps: if true, we learn epsilon; if false, we keep it
|
||||||
at the initial value.
|
at the initial value.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
num_channels: int,
|
def __init__(
|
||||||
channel_dim: int = -1, # CAUTION: see documentation.
|
self,
|
||||||
eps: float = 0.25,
|
num_channels: int,
|
||||||
learn_eps: bool = True) -> None:
|
channel_dim: int = -1, # CAUTION: see documentation.
|
||||||
|
eps: float = 0.25,
|
||||||
|
learn_eps: bool = True,
|
||||||
|
) -> None:
|
||||||
super(BasicNorm, self).__init__()
|
super(BasicNorm, self).__init__()
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.channel_dim = channel_dim
|
self.channel_dim = channel_dim
|
||||||
if learn_eps:
|
if learn_eps:
|
||||||
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
||||||
else:
|
else:
|
||||||
self.register_buffer('eps', torch.tensor(eps).log().detach())
|
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
assert x.shape[self.channel_dim] == self.num_channels
|
assert x.shape[self.channel_dim] == self.num_channels
|
||||||
scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
|
scales = (
|
||||||
self.eps.exp()) ** -0.5
|
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
|
||||||
|
+ self.eps.exp()
|
||||||
|
) ** -0.5
|
||||||
return x * scales
|
return x * scales
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ScaledLinear(nn.Linear):
|
class ScaledLinear(nn.Linear):
|
||||||
"""
|
"""
|
||||||
A modified version of nn.Linear where the parameters are scaled before
|
A modified version of nn.Linear where the parameters are scaled before
|
||||||
@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear):
|
|||||||
Alternatively you can set it to more than 1 if you want it to
|
Alternatively you can set it to more than 1 if you want it to
|
||||||
initially train faster. Must be greater than 0.
|
initially train faster. Must be greater than 0.
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args,
|
|
||||||
initial_scale: float = 1.0,
|
def __init__(
|
||||||
initial_speed: float = 1.0,
|
self,
|
||||||
**kwargs):
|
*args,
|
||||||
|
initial_scale: float = 1.0,
|
||||||
|
initial_speed: float = 1.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
super(ScaledLinear, self).__init__(*args, **kwargs)
|
super(ScaledLinear, self).__init__(*args, **kwargs)
|
||||||
initial_scale = torch.tensor(initial_scale).log()
|
initial_scale = torch.tensor(initial_scale).log()
|
||||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||||
else:
|
else:
|
||||||
self.register_parameter('bias_scale', None)
|
self.register_parameter("bias_scale", None)
|
||||||
|
|
||||||
self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear
|
self._reset_parameters(
|
||||||
|
initial_speed
|
||||||
|
) # Overrides the reset_parameters in nn.Linear
|
||||||
|
|
||||||
def _reset_parameters(self, initial_speed: float):
|
def _reset_parameters(self, initial_speed: float):
|
||||||
std = 0.1 / initial_speed
|
std = 0.1 / initial_speed
|
||||||
@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear):
|
|||||||
return self.weight * self.weight_scale.exp()
|
return self.weight * self.weight_scale.exp()
|
||||||
|
|
||||||
def get_bias(self):
|
def get_bias(self):
|
||||||
return (None if self.bias is None else
|
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||||
self.bias * self.bias_scale.exp())
|
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
return torch.nn.functional.linear(input, self.get_weight(),
|
return torch.nn.functional.linear(
|
||||||
self.get_bias())
|
input, self.get_weight(), self.get_bias()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ScaledConv1d(nn.Conv1d):
|
class ScaledConv1d(nn.Conv1d):
|
||||||
# See docs for ScaledLinear
|
# See docs for ScaledLinear
|
||||||
def __init__(self, *args,
|
def __init__(
|
||||||
initial_scale: float = 1.0,
|
self,
|
||||||
initial_speed: float = 1.0,
|
*args,
|
||||||
**kwargs):
|
initial_scale: float = 1.0,
|
||||||
|
initial_speed: float = 1.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
super(ScaledConv1d, self).__init__(*args, **kwargs)
|
super(ScaledConv1d, self).__init__(*args, **kwargs)
|
||||||
initial_scale = torch.tensor(initial_scale).log()
|
initial_scale = torch.tensor(initial_scale).log()
|
||||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||||
else:
|
else:
|
||||||
self.register_parameter('bias_scale', None)
|
self.register_parameter("bias_scale", None)
|
||||||
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
|
self._reset_parameters(
|
||||||
|
initial_speed
|
||||||
|
) # Overrides the reset_parameters in base class
|
||||||
|
|
||||||
def _reset_parameters(self, initial_speed: float):
|
def _reset_parameters(self, initial_speed: float):
|
||||||
std = 0.1 / initial_speed
|
std = 0.1 / initial_speed
|
||||||
@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight_scale += torch.tensor(scale / std).log()
|
self.weight_scale += torch.tensor(scale / std).log()
|
||||||
|
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
return self.weight * self.weight_scale.exp()
|
return self.weight * self.weight_scale.exp()
|
||||||
|
|
||||||
def get_bias(self):
|
def get_bias(self):
|
||||||
return (None if self.bias is None else
|
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||||
self.bias * self.bias_scale.exp())
|
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
F = torch.nn.functional
|
F = torch.nn.functional
|
||||||
if self.padding_mode != 'zeros':
|
if self.padding_mode != "zeros":
|
||||||
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
|
return F.conv1d(
|
||||||
self.get_weight(), self.get_bias(), self.stride,
|
F.pad(
|
||||||
_single(0), self.dilation, self.groups)
|
input,
|
||||||
return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
|
self._reversed_padding_repeated_twice,
|
||||||
self.padding, self.dilation, self.groups)
|
mode=self.padding_mode,
|
||||||
|
),
|
||||||
|
self.get_weight(),
|
||||||
|
self.get_bias(),
|
||||||
|
self.stride,
|
||||||
|
_single(0),
|
||||||
|
self.dilation,
|
||||||
|
self.groups,
|
||||||
|
)
|
||||||
|
return F.conv1d(
|
||||||
|
input,
|
||||||
|
self.get_weight(),
|
||||||
|
self.get_bias(),
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ScaledConv2d(nn.Conv2d):
|
class ScaledConv2d(nn.Conv2d):
|
||||||
# See docs for ScaledLinear
|
# See docs for ScaledLinear
|
||||||
def __init__(self, *args,
|
def __init__(
|
||||||
initial_scale: float = 1.0,
|
self,
|
||||||
initial_speed: float = 1.0,
|
*args,
|
||||||
**kwargs):
|
initial_scale: float = 1.0,
|
||||||
|
initial_speed: float = 1.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
super(ScaledConv2d, self).__init__(*args, **kwargs)
|
super(ScaledConv2d, self).__init__(*args, **kwargs)
|
||||||
initial_scale = torch.tensor(initial_scale).log()
|
initial_scale = torch.tensor(initial_scale).log()
|
||||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||||
else:
|
else:
|
||||||
self.register_parameter('bias_scale', None)
|
self.register_parameter("bias_scale", None)
|
||||||
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
|
self._reset_parameters(
|
||||||
|
initial_speed
|
||||||
|
) # Overrides the reset_parameters in base class
|
||||||
|
|
||||||
def _reset_parameters(self, initial_speed: float):
|
def _reset_parameters(self, initial_speed: float):
|
||||||
std = 0.1 / initial_speed
|
std = 0.1 / initial_speed
|
||||||
@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight_scale += torch.tensor(scale / std).log()
|
self.weight_scale += torch.tensor(scale / std).log()
|
||||||
|
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
return self.weight * self.weight_scale.exp()
|
return self.weight * self.weight_scale.exp()
|
||||||
|
|
||||||
def get_bias(self):
|
def get_bias(self):
|
||||||
return (None if self.bias is None else
|
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||||
self.bias * self.bias_scale.exp())
|
|
||||||
|
|
||||||
def _conv_forward(self, input, weight):
|
def _conv_forward(self, input, weight):
|
||||||
F = torch.nn.functional
|
F = torch.nn.functional
|
||||||
if self.padding_mode != 'zeros':
|
if self.padding_mode != "zeros":
|
||||||
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
|
return F.conv2d(
|
||||||
weight, self.get_bias(), self.stride,
|
F.pad(
|
||||||
_pair(0), self.dilation, self.groups)
|
input,
|
||||||
return F.conv2d(input, weight, self.get_bias(), self.stride,
|
self._reversed_padding_repeated_twice,
|
||||||
self.padding, self.dilation, self.groups)
|
mode=self.padding_mode,
|
||||||
|
),
|
||||||
|
weight,
|
||||||
|
self.get_bias(),
|
||||||
|
self.stride,
|
||||||
|
_pair(0),
|
||||||
|
self.dilation,
|
||||||
|
self.groups,
|
||||||
|
)
|
||||||
|
return F.conv2d(
|
||||||
|
input,
|
||||||
|
weight,
|
||||||
|
self.get_bias(),
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.groups,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
return self._conv_forward(input, self.get_weight())
|
return self._conv_forward(input, self.get_weight())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ActivationBalancer(torch.nn.Module):
|
class ActivationBalancer(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Modifies the backpropped derivatives of a function to try to encourage, for
|
Modifies the backpropped derivatives of a function to try to encourage, for
|
||||||
@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
we allow, before we start to modify the derivatives to prevent
|
we allow, before we start to modify the derivatives to prevent
|
||||||
this.
|
this.
|
||||||
"""
|
"""
|
||||||
def __init__(self, channel_dim: int,
|
|
||||||
min_positive: float = 0.05,
|
def __init__(
|
||||||
max_positive: float = 0.95,
|
self,
|
||||||
max_factor: float = 0.01,
|
channel_dim: int,
|
||||||
min_abs: float = 0.2,
|
min_positive: float = 0.05,
|
||||||
max_abs: float = 100.0):
|
max_positive: float = 0.95,
|
||||||
|
max_factor: float = 0.01,
|
||||||
|
min_abs: float = 0.2,
|
||||||
|
max_abs: float = 100.0,
|
||||||
|
):
|
||||||
super(ActivationBalancer, self).__init__()
|
super(ActivationBalancer, self).__init__()
|
||||||
self.channel_dim = channel_dim
|
self.channel_dim = channel_dim
|
||||||
self.min_positive = min_positive
|
self.min_positive = min_positive
|
||||||
@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
self.max_abs = max_abs
|
self.max_abs = max_abs
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return ActivationBalancerFunction.apply(x, self.channel_dim,
|
return ActivationBalancerFunction.apply(
|
||||||
self.min_positive, self.max_positive,
|
x,
|
||||||
self.max_factor, self.min_abs,
|
self.channel_dim,
|
||||||
self.max_abs)
|
self.min_positive,
|
||||||
|
self.max_positive,
|
||||||
|
self.max_factor,
|
||||||
|
self.min_abs,
|
||||||
|
self.max_abs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DoubleSwishFunction(torch.autograd.Function):
|
class DoubleSwishFunction(torch.autograd.Function):
|
||||||
@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function):
|
|||||||
= double_swish(x) * (1-s(x)) + s(x)
|
= double_swish(x) * (1-s(x)) + s(x)
|
||||||
... so we just need to remember s(x) but not x itself.
|
... so we just need to remember s(x) but not x itself.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor) -> Tensor:
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
@ -349,18 +436,17 @@ class DoubleSwishFunction(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, y_grad: Tensor) -> Tensor:
|
def backward(ctx, y_grad: Tensor) -> Tensor:
|
||||||
s, y = ctx.saved_tensors
|
s, y = ctx.saved_tensors
|
||||||
return (y * (1-s) + s) * y_grad
|
return (y * (1 - s) + s) * y_grad
|
||||||
|
|
||||||
|
|
||||||
class DoubleSwish(torch.nn.Module):
|
class DoubleSwish(torch.nn.Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||||
that we approximate closely with x * sigmoid(x-1).
|
that we approximate closely with x * sigmoid(x-1).
|
||||||
"""
|
"""
|
||||||
return DoubleSwishFunction.apply(x)
|
return DoubleSwishFunction.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ScaledEmbedding(nn.Module):
|
class ScaledEmbedding(nn.Module):
|
||||||
r"""This is a modified version of nn.Embedding that introduces a learnable scale
|
r"""This is a modified version of nn.Embedding that introduces a learnable scale
|
||||||
on the parameters. Note: due to how we initialize it, it's best used with
|
on the parameters. Note: due to how we initialize it, it's best used with
|
||||||
@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module):
|
|||||||
[-0.1655, 0.9897, 0.0635]]])
|
[-0.1655, 0.9897, 0.0635]]])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
|
__constants__ = [
|
||||||
'scale_grad_by_freq', 'sparse']
|
"num_embeddings",
|
||||||
|
"embedding_dim",
|
||||||
|
"padding_idx",
|
||||||
|
"scale_grad_by_freq",
|
||||||
|
"sparse",
|
||||||
|
]
|
||||||
|
|
||||||
num_embeddings: int
|
num_embeddings: int
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
@ -453,33 +544,41 @@ class ScaledEmbedding(nn.Module):
|
|||||||
weight: Tensor
|
weight: Tensor
|
||||||
sparse: bool
|
sparse: bool
|
||||||
|
|
||||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
|
def __init__(
|
||||||
scale_grad_by_freq: bool = False,
|
self,
|
||||||
sparse: bool = False,
|
num_embeddings: int,
|
||||||
initial_speed: float = 1.0) -> None:
|
embedding_dim: int,
|
||||||
|
padding_idx: Optional[int] = None,
|
||||||
|
scale_grad_by_freq: bool = False,
|
||||||
|
sparse: bool = False,
|
||||||
|
initial_speed: float = 1.0,
|
||||||
|
) -> None:
|
||||||
super(ScaledEmbedding, self).__init__()
|
super(ScaledEmbedding, self).__init__()
|
||||||
self.num_embeddings = num_embeddings
|
self.num_embeddings = num_embeddings
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
if padding_idx is not None:
|
if padding_idx is not None:
|
||||||
if padding_idx > 0:
|
if padding_idx > 0:
|
||||||
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
assert (
|
||||||
|
padding_idx < self.num_embeddings
|
||||||
|
), "Padding_idx must be within num_embeddings"
|
||||||
elif padding_idx < 0:
|
elif padding_idx < 0:
|
||||||
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
assert (
|
||||||
|
padding_idx >= -self.num_embeddings
|
||||||
|
), "Padding_idx must be within num_embeddings"
|
||||||
padding_idx = self.num_embeddings + padding_idx
|
padding_idx = self.num_embeddings + padding_idx
|
||||||
self.padding_idx = padding_idx
|
self.padding_idx = padding_idx
|
||||||
self.scale_grad_by_freq = scale_grad_by_freq
|
self.scale_grad_by_freq = scale_grad_by_freq
|
||||||
|
|
||||||
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
|
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
|
||||||
self.sparse = sparse
|
self.sparse = sparse
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
||||||
self.reset_parameters(initial_speed)
|
self.reset_parameters(initial_speed)
|
||||||
|
|
||||||
|
|
||||||
def reset_parameters(self, initial_speed: float = 1.0) -> None:
|
def reset_parameters(self, initial_speed: float = 1.0) -> None:
|
||||||
std = 0.1 / initial_speed
|
std = 0.1 / initial_speed
|
||||||
nn.init.normal_(self.weight, std=std)
|
nn.init.normal_(self.weight, std=std)
|
||||||
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
|
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
|
||||||
|
|
||||||
if self.padding_idx is not None:
|
if self.padding_idx is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module):
|
|||||||
F = torch.nn.functional
|
F = torch.nn.functional
|
||||||
scale = self.scale.exp()
|
scale = self.scale.exp()
|
||||||
if input.numel() < self.num_embeddings:
|
if input.numel() < self.num_embeddings:
|
||||||
return F.embedding(
|
return (
|
||||||
input, self.weight, self.padding_idx,
|
F.embedding(
|
||||||
None, 2.0, # None, 2.0 relate to normalization
|
input,
|
||||||
self.scale_grad_by_freq, self.sparse) * scale
|
self.weight,
|
||||||
|
self.padding_idx,
|
||||||
|
None,
|
||||||
|
2.0, # None, 2.0 relate to normalization
|
||||||
|
self.scale_grad_by_freq,
|
||||||
|
self.sparse,
|
||||||
|
)
|
||||||
|
* scale
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return F.embedding(
|
return F.embedding(
|
||||||
input, self.weight * scale, self.padding_idx,
|
input,
|
||||||
None, 2.0, # None, 2.0 relates to normalization
|
self.weight * scale,
|
||||||
self.scale_grad_by_freq, self.sparse)
|
self.padding_idx,
|
||||||
|
None,
|
||||||
|
2.0, # None, 2.0 relates to normalization
|
||||||
|
self.scale_grad_by_freq,
|
||||||
|
self.sparse,
|
||||||
|
)
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
s = '{num_embeddings}, {embedding_dim}, scale={scale}'
|
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
|
||||||
if self.padding_idx is not None:
|
if self.padding_idx is not None:
|
||||||
s += ', padding_idx={padding_idx}'
|
s += ", padding_idx={padding_idx}"
|
||||||
if self.scale_grad_by_freq is not False:
|
if self.scale_grad_by_freq is not False:
|
||||||
s += ', scale_grad_by_freq={scale_grad_by_freq}'
|
s += ", scale_grad_by_freq={scale_grad_by_freq}"
|
||||||
if self.sparse is not False:
|
if self.sparse is not False:
|
||||||
s += ', sparse=True'
|
s += ", sparse=True"
|
||||||
return s.format(**self.__dict__)
|
return s.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
def _test_activation_balancer_sign():
|
def _test_activation_balancer_sign():
|
||||||
channel_dim = 0
|
|
||||||
probs = torch.arange(0, 1, 0.01)
|
probs = torch.arange(0, 1, 0.01)
|
||||||
N = 1000
|
N = 1000
|
||||||
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
|
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
|
m = ActivationBalancer(
|
||||||
max_factor=0.2, min_abs=0.0)
|
channel_dim=0,
|
||||||
|
min_positive=0.05,
|
||||||
|
max_positive=0.95,
|
||||||
|
max_factor=0.2,
|
||||||
|
min_abs=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
||||||
|
|
||||||
@ -528,17 +644,23 @@ def _test_activation_balancer_sign():
|
|||||||
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
||||||
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
||||||
|
|
||||||
|
|
||||||
def _test_activation_balancer_magnitude():
|
def _test_activation_balancer_magnitude():
|
||||||
channel_dim = 0
|
|
||||||
magnitudes = torch.arange(0, 1, 0.01)
|
magnitudes = torch.arange(0, 1, 0.01)
|
||||||
N = 1000
|
N = 1000
|
||||||
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
|
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
||||||
|
-1
|
||||||
|
)
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
m = ActivationBalancer(channel_dim=0,
|
m = ActivationBalancer(
|
||||||
min_positive=0.0, max_positive=1.0,
|
channel_dim=0,
|
||||||
max_factor=0.2,
|
min_positive=0.0,
|
||||||
min_abs=0.2, max_abs=0.8)
|
max_positive=1.0,
|
||||||
|
max_factor=0.2,
|
||||||
|
min_abs=0.2,
|
||||||
|
max_abs=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
||||||
|
|
||||||
@ -558,8 +680,8 @@ def _test_basic_norm():
|
|||||||
y = m(x)
|
y = m(x)
|
||||||
|
|
||||||
assert y.shape == x.shape
|
assert y.shape == x.shape
|
||||||
x_rms = (x**2).mean().sqrt()
|
x_rms = (x ** 2).mean().sqrt()
|
||||||
y_rms = (y**2).mean().sqrt()
|
y_rms = (y ** 2).mean().sqrt()
|
||||||
print("x rms = ", x_rms)
|
print("x rms = ", x_rms)
|
||||||
print("y rms = ", y_rms)
|
print("y rms = ", y_rms)
|
||||||
assert y_rms < x_rms
|
assert y_rms < x_rms
|
||||||
@ -573,7 +695,7 @@ def _test_double_swish_deriv():
|
|||||||
torch.autograd.gradcheck(m, x)
|
torch.autograd.gradcheck(m, x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
_test_activation_balancer_sign()
|
_test_activation_balancer_sign()
|
||||||
_test_activation_balancer_magnitude()
|
_test_activation_balancer_magnitude()
|
||||||
_test_basic_norm()
|
_test_basic_norm()
|
||||||
|
@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
# use duck typing for LRScheduler since we have different possibilities, see
|
# use duck typing for LRScheduler since we have different possibilities, see
|
||||||
# our class LRScheduler.
|
# our class LRScheduler.
|
||||||
LRSchedulerType = object
|
LRSchedulerType = object
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(
|
def save_checkpoint(
|
||||||
filename: Path,
|
filename: Path,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user