mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Code style check for librispeech pruned transducer stateless2 (#308)
This commit is contained in:
parent 8cb727e24a
commit 93c60a9d30

.flake8 (2 changed lines)
@@ -7,6 +7,8 @@ per-file-ignores =
     egs/librispeech/ASR/*/conformer.py: E501,
     egs/aishell/ASR/*/conformer.py: E501,
     egs/tedlium3/ASR/*/conformer.py: E501,
+    egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
+
     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
 
@@ -93,7 +93,9 @@ def fast_beam_search(
         )
         # fmt: on
         logits = model.joiner(
-            current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False
+            current_encoder_out.unsqueeze(2),
+            decoder_out.unsqueeze(1),
+            project_input=False,
         )
         logits = logits.squeeze(1).squeeze(1)
         log_probs = logits.log_softmax(dim=-1)
@@ -140,7 +142,6 @@ def greedy_search(
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
-
    T = encoder_out.size(1)
    t = 0
    hyp = [blank_id] * context_size
@@ -163,9 +164,9 @@ def greedy_search(
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
         # fmt: on
-        logits = model.joiner(current_encoder_out,
-                              decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
         # logits is (1, 1, 1, vocab_size)
 
         y = logits.argmax().item()
@@ -228,8 +229,9 @@ def greedy_search_batch(
     for t in range(T):
         current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
         # logits'shape (batch_size, 1, 1, vocab_size)
 
         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
@@ -466,7 +468,6 @@ def modified_beam_search(
         decoder_out = model.joiner.decoder_proj(decoder_out)
         # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
 
-
         # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
         # as index, so we use `to(torch.int64)` below.
         current_encoder_out = torch.index_select(
@@ -720,7 +721,7 @@ def beam_search(
            logits = model.joiner(
                current_encoder_out,
                decoder_out.unsqueeze(1),
-                project_input=False
+                project_input=False,
            )
 
            # TODO(fangjun): Scale the blank posterior
@@ -16,13 +16,20 @@
 # limitations under the License.
 
 import copy
-from encoder_interface import EncoderInterface
 import math
 import warnings
-from typing import Optional, Tuple, Sequence
-from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from typing import Optional, Tuple
 
 import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
 from torch import Tensor, nn
 
 from icefall.utils import make_pad_mask
@@ -42,6 +49,7 @@ class Conformer(EncoderInterface):
         cnn_module_kernel (int): Kernel size of convolution module
         vgg_frontend (bool): whether to use vgg frontend.
     """
+
     def __init__(
         self,
         num_features: int,
@@ -80,9 +88,8 @@ class Conformer(EncoderInterface):
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
 
-
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,8 +119,9 @@ class Conformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup=warmup)  # (T, N, C)
+        x = self.encoder(
+            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+        )  # (T, N, C)
 
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
@@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-
         self.norm_final = BasicNorm(d_model)
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(channel_dim=-1,
-                                           min_positive=0.45,
-                                           max_positive=0.55,
-                                           max_abs=6.0)
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
 
         self.dropout = nn.Dropout(dropout)
 
-
     def forward(
         self,
         src: Tensor,
@@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module):
         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
         # completely bypass it.
         if self.training:
-            alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else 0.1
+            )
         else:
             alpha = 1.0
 
         # macaron style feed forward module
         src = src + self.dropout(self.feed_forward_macaron(src))
 
-
         # multi-headed self-attention module
         src_att = self.self_attn(
             src,
@@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))
 
         if alpha != 1.0:
-            src = alpha * src + (1-alpha) * src_orig
+            src = alpha * src + (1 - alpha) * src_orig
 
         return src
 
@@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module):
         )
         self.num_layers = num_layers
 
-
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup: float = 1.0
+        warmup: float = 1.0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module):
         """
         output = src
 
-        num_layers = len(self.layers)
-
         for i, mod in enumerate(self.layers):
             output = mod(
                 output,
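Note: the warmup logic reformatted above is easier to read in isolation. During training each ConformerEncoderLayer keeps its output with probability 1 - layer_dropout (scaled by warmup_scale) and is otherwise mostly bypassed. A minimal stand-alone sketch of that blend in plain PyTorch; the concrete warmup_scale/layer_dropout values here are only illustrative:

import torch

def blend_with_warmup(src_orig: torch.Tensor,
                      src: torch.Tensor,
                      warmup_scale: float,
                      layer_dropout: float,
                      training: bool) -> torch.Tensor:
    # Mirrors the reformatted logic above: pick alpha stochastically during
    # training, then interpolate between the layer output and its input.
    if training:
        alpha = (
            warmup_scale
            if torch.rand(()).item() <= (1.0 - layer_dropout)
            else 0.1
        )
    else:
        alpha = 1.0
    if alpha != 1.0:
        src = alpha * src + (1 - alpha) * src_orig
    return src

x = torch.randn(20, 5, 256)            # (T, N, C) input to the layer
y = x + 0.1 * torch.randn_like(x)      # stand-in for the layer's output
out = blend_with_warmup(x, y, warmup_scale=0.5, layer_dropout=0.075, training=True)
print(out.shape)                       # torch.Size([20, 5, 256])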
@@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
 
         self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
 
         # linear transformation for positional encoding.
         self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
             _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
 
-
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
             _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
 
         self.depthwise_conv = ScaledConv1d(
             channels,
@@ -878,9 +885,9 @@ class ConvolutionModule(nn.Module):
             bias=bias,
         )
 
-        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
 
         self.activation = DoubleSwish()
 
@@ -891,7 +898,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25
+            initial_scale=0.25,
         )
 
     def forward(self, x: Tensor) -> Tensor:
@@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)
 
 
-
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).
 
@@ -936,11 +942,14 @@ class Conv2dSubsampling(nn.Module):
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
     """
 
-    def __init__(self, in_channels: int,
-                 out_channels: int,
-                 layer1_channels: int = 8,
-                 layer2_channels: int = 32,
-                 layer3_channels: int = 128) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+    ) -> None:
         """
         Args:
           in_channels:
@@ -958,34 +967,41 @@ class Conv2dSubsampling(nn.Module):
 
         self.conv = nn.Sequential(
             ScaledConv2d(
-                in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, padding=1,
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
-                in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2,
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
-                in_channels=layer2_channels, out_channels=layer3_channels,
-                kernel_size=3, stride=2,
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
-        self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
         # constrain median of output to be close to zero.
-        self.out_balancer = ActivationBalancer(channel_dim=-1,
-                                               min_positive=0.45,
-                                               max_positive=0.55)
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )
 
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
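Note: the ScaledLinear at the end of Conv2dSubsampling is sized by the same arithmetic that determines the subsampled length: the first 3x3 conv with padding=1 preserves size, and each of the two stride-2 convs maps n to (n - 1) // 2. A quick check of that arithmetic; the in_channels and T values below are only illustrative:

def subsampled_len(n: int) -> int:
    # One 3x3 conv with padding=1 keeps the size; the two stride-2 convs
    # (kernel 3, no padding) each map n -> (n - 1) // 2.
    return ((n - 1) // 2 - 1) // 2

in_channels = 80          # e.g. number of mel bins (illustrative)
layer3_channels = 128
T = 100                   # number of input frames (illustrative)

print(subsampled_len(T))                               # 24 output frames
print(layer3_channels * subsampled_len(in_channels))   # flattened dim fed to ScaledLinear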
@@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module):
         return x
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
-    f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64),
-          warmup=0.5)
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
@@ -17,9 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional
-from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
+from scaling import ScaledConv1d, ScaledEmbedding
 
 
 class Decoder(nn.Module):
@@ -16,15 +16,17 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from scaling import ScaledLinear
 
+
 class Joiner(nn.Module):
-    def __init__(self,
-                 encoder_dim: int,
-                 decoder_dim: int,
-                 joiner_dim: int,
-                 vocab_size: int):
+    def __init__(
+        self,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
         super().__init__()
 
         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
@@ -32,8 +34,10 @@ class Joiner(nn.Module):
         self.output_linear = ScaledLinear(joiner_dim, vocab_size)
 
     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
-        project_input: bool = True
+        self,
+        encoder_out: torch.Tensor,
+        decoder_out: torch.Tensor,
+        project_input: bool = True,
     ) -> torch.Tensor:
         """
         Args:
@@ -52,7 +56,9 @@ class Joiner(nn.Module):
         assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
 
         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
+                decoder_out
+            )
         else:
             logit = encoder_out + decoder_out
 
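Note on the project_input flag reformatted above: when it is True the joiner applies encoder_proj/decoder_proj itself before adding the two streams; when it is False the caller is expected to have projected to joiner_dim already, which is exactly what beam_search.py and model.py do. A minimal shape sketch of the per-frame call used by greedy search; it assumes the Joiner class from this file is importable (e.g. the recipe directory is on the path) and the dimensions are only illustrative:

import torch
from joiner import Joiner  # the class shown in the hunks above (path assumed)

joiner = Joiner(encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500)

batch_size = 4
# One encoder frame and one decoder state per utterance, as in greedy_search_batch:
current_encoder_out = torch.randn(batch_size, 1, 1, 512)   # (N, 1, 1, encoder_dim)
decoder_out = torch.randn(batch_size, 1, 512)              # (N, 1, decoder_dim)

# project_input=True lets the joiner apply its own projections; beam_search.py
# passes project_input=False because it calls encoder_proj/decoder_proj beforehand.
logits = joiner(current_encoder_out, decoder_out.unsqueeze(1), project_input=True)
print(logits.shape)                         # torch.Size([4, 1, 1, 500])
print(logits.squeeze(1).squeeze(1).shape)   # torch.Size([4, 500]), ready for argmax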
@@ -37,7 +37,7 @@ class Transducer(nn.Module):
         encoder_dim: int,
         decoder_dim: int,
         joiner_dim: int,
-        vocab_size: int
+        vocab_size: int,
     ):
         """
         Args:
@@ -48,11 +48,11 @@ class Transducer(nn.Module):
             `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
-            is (N, U) and its output shape is (N, U, decoder_dim). It should contain
-            one attribute: `blank_id`.
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
          joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
-            output shape is (N, T, U, vocab_size). Note that its output contains
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
@@ -63,8 +63,9 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
-        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
-                                           initial_speed=0.5)
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_speed=0.5
+        )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
     def forward(
@@ -141,8 +142,8 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        lm=self.simple_lm_proj(decoder_out)
-        am=self.simple_am_proj(encoder_out)
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
 
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
@@ -170,15 +171,14 @@ class Transducer(nn.Module):
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             am=self.joiner.encoder_proj(encoder_out),
             lm=self.joiner.decoder_proj(decoder_out),
-            ranges=ranges
+            ranges=ranges,
         )
 
         # logits : [B, T, prune_range, vocab_size]
 
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned,
-                             project_input=False)
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
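Note: the comment kept above about applying the joiner's input projections before do_rnnt_pruning is a genuine saving. Projecting beforehand runs the encoder_dim-to-joiner_dim matmul once per frame and once per token, instead of once per (frame, pruned-token) pair for each of the two pruned streams. A rough operation count under purely illustrative sizes (a sketch, not a benchmark):

# Illustrative sizes: batch N, frames T, tokens U, prune_range S, dims D -> J.
N, T, U, S, D, J = 8, 500, 60, 5, 512, 512

proj_before = N * (T + U) * D * J       # project encoder_out and decoder_out once each
proj_after = 2 * N * T * S * D * J      # project both pruned streams of shape (N, T, S, D)
print(proj_before, proj_after, round(proj_after / proj_before, 1))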
@@ -15,11 +15,9 @@
 # limitations under the License.
 
 
-import random
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import torch
-from torch import Tensor
 from torch.optim import Optimizer
 
 
@@ -59,24 +57,41 @@ class Eve(Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
-                 weight_decay=1e-3, target_rms=0.1):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.98),
+        eps=1e-8,
+        weight_decay=1e-3,
+        target_rms=0.1,
+    ):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+            raise ValueError(
+                "Invalid beta parameter at index 0: {}".format(betas[0])
+            )
         if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+            raise ValueError(
+                "Invalid beta parameter at index 1: {}".format(betas[1])
+            )
         if not 0 <= weight_decay <= 0.1:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0 < target_rms <= 10.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay,
-                        target_rms=target_rms)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            target_rms=target_rms,
+        )
         super(Eve, self).__init__(params, defaults)
 
     def __setstate__(self, state):
@@ -96,83 +111,98 @@ class Eve(Optimizer):
             loss = closure()
 
         for group in self.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is None:
                     continue
 
                 # Perform optimization step
                 grad = p.grad
                 if grad.is_sparse:
-                    raise RuntimeError('AdamW does not support sparse gradients')
+                    raise RuntimeError(
+                        "AdamW does not support sparse gradients"
+                    )
 
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
-                    state['step'] = 0
+                    state["step"] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
 
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
 
-                beta1, beta2 = group['betas']
+                beta1, beta2 = group["betas"]
 
-                state['step'] += 1
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
+                state["step"] += 1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
 
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
+                    group["eps"]
+                )
 
-                step_size = group['lr'] / bias_correction1
-                target_rms = group['target_rms']
-                weight_decay = group['weight_decay']
-                delta = exp_avg / denom
+                step_size = group["lr"] / bias_correction1
+                target_rms = group["target_rms"]
+                weight_decay = group["weight_decay"]
 
                 if p.numel() > 1:
                     # avoid applying this weight-decay on "scaling factors"
                     # (which are scalar).
-                    is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
+                    is_above_target_rms = p.norm() > (
+                        target_rms * (p.numel() ** 0.5)
+                    )
                     p.mul_(1 - (weight_decay * is_above_target_rms))
                 p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
 
 
 class LRScheduler(object):
     """
     Base-class for learning rate schedulers where the learning-rate depends on both the
     batch and the epoch.
     """
 
     def __init__(self, optimizer: Optimizer, verbose: bool = False):
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
-            raise TypeError('{} is not an Optimizer'.format(
-                type(optimizer).__name__))
+            raise TypeError(
+                "{} is not an Optimizer".format(type(optimizer).__name__)
+            )
         self.optimizer = optimizer
         self.verbose = verbose
 
         for group in optimizer.param_groups:
-            group.setdefault('initial_lr', group['lr'])
+            group.setdefault("initial_lr", group["lr"])
 
-        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+        self.base_lrs = [
+            group["initial_lr"] for group in optimizer.param_groups
+        ]
 
         self.epoch = 0
         self.batch = 0
 
     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
 
         It contains an entry for every variable in self.__dict__ which
         is not the optimizer.
         """
-        return {'base_lrs': self.base_lrs,
-                'epoch': self.epoch,
-                'batch': self.batch}
+        return {
+            "base_lrs": self.base_lrs,
+            "epoch": self.epoch,
+            "batch": self.batch,
+        }
 
     def load_state_dict(self, state_dict):
         """Loads the schedulers state.
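Note: the reformatted step() above makes the distinguishing feature of Eve easier to see. It is AdamW-like, but the multiplicative weight decay is applied only to tensors with more than one element and only while their norm exceeds target_rms * sqrt(numel), i.e. while their RMS is above target_rms. A small numeric illustration of that condition:

import torch

target_rms = 0.1

p = torch.randn(256, 256) * 0.2   # RMS around 0.2, above the target
is_above = p.norm() > (target_rms * (p.numel() ** 0.5))
print(bool(is_above))             # True: decay is applied on this step

p = torch.randn(256, 256) * 0.05  # RMS around 0.05, below the target
is_above = p.norm() > (target_rms * (p.numel() ** 0.5))
print(bool(is_above))             # False: the 1 - weight_decay * is_above factor is 1.0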
@@ -184,8 +214,7 @@ class LRScheduler(object):
         self.__dict__.update(state_dict)
 
     def get_last_lr(self) -> List[float]:
-        """ Return last computed learning rate by current scheduler. Will be a list of float.
-        """
+        """Return last computed learning rate by current scheduler. Will be a list of float."""
         return self._last_lr
 
     def get_lr(self):
@@ -194,7 +223,6 @@ class LRScheduler(object):
         # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
         raise NotImplementedError
 
-
     def step_batch(self, batch: Optional[int] = None) -> None:
         # Step the batch index, or just set it. If `batch` is specified, it
         # must be the batch index from the start of training, i.e. summed over
@@ -217,24 +245,23 @@ class LRScheduler(object):
             self.epoch = self.epoch + 1
         self._set_lrs()
 
-
     def _set_lrs(self):
         values = self.get_lr()
         assert len(values) == len(self.optimizer.param_groups)
 
         for i, data in enumerate(zip(self.optimizer.param_groups, values)):
             param_group, lr = data
-            param_group['lr'] = lr
+            param_group["lr"] = lr
             self.print_lr(self.verbose, i, lr)
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
 
-
     def print_lr(self, is_verbose, group, lr):
-        """Display the current learning rate.
-        """
+        """Display the current learning rate."""
         if is_verbose:
-            print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
-                  f' of group {group} to {lr:.4e}.')
+            print(
+                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
+                f" of group {group} to {lr:.4e}."
+            )
 
 
 class Eden(LRScheduler):
@@ -254,18 +281,27 @@ class Eden(LRScheduler):
        20 to 40 epochs, but may need smaller number if dataset is huge
        and you will do few epochs.
     """
-    def __init__(self, optimizer: Optimizer,
-                 lr_batches: Union[int, float],
-                 lr_epochs: Union[int, float],
-                 verbose: bool = False):
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        lr_batches: Union[int, float],
+        lr_epochs: Union[int, float],
+        verbose: bool = False,
+    ):
         super(Eden, self).__init__(optimizer, verbose)
         self.lr_batches = lr_batches
         self.lr_epochs = lr_epochs
 
     def get_lr(self):
-        factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
-                  (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
-        return [ x * factor for x in self.base_lrs ]
+        factor = (
+            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
+        ) ** -0.25 * (
+            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
+            ** -0.25
+        )
+        return [x * factor for x in self.base_lrs]
 
+
 def _test_eden():
     m = torch.nn.Linear(100, 100)
@@ -290,5 +326,6 @@ def _test_eden():
     print("last lr = ", scheduler.get_last_lr())
     print("state dict = ", scheduler.state_dict())
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     _test_eden()
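Note: Eden's get_lr(), reformatted above, multiplies two smooth decay terms, one in the batch index and one in the epoch index. The schedule is essentially flat while batch << lr_batches and epoch << lr_epochs, and each term then decays roughly like a -0.5 power. The same factor written out directly; the lr_batches/lr_epochs values below are only illustrative:

def eden_factor(batch: int, epoch: int, lr_batches: float, lr_epochs: float) -> float:
    # Matches the expression in Eden.get_lr() above.
    return ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25 * (
        ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    )

for batch, epoch in [(0, 0), (5000, 1), (50000, 6), (200000, 20)]:
    print(batch, epoch, round(eden_factor(batch, epoch, lr_batches=5000, lr_epochs=6), 4))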
@@ -15,54 +15,86 @@
 # limitations under the License.
 
 
+import collections
+from itertools import repeat
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
 from torch import Tensor
-from typing import Tuple, Optional
 
 
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+_single = _ntuple(1)
+_pair = _ntuple(2)
+
+
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor,
-                channel_dim: int,
-                min_positive: float,  # e.g. 0.05
-                max_positive: float,  # e.g. 0.95
-                max_factor: float,  # e.g. 0.01
-                min_abs: float,  # e.g. 0.2
-                max_abs: float,  # e.g. 100.0
+    def forward(
+        ctx,
+        x: Tensor,
+        channel_dim: int,
+        min_positive: float,  # e.g. 0.05
+        max_positive: float,  # e.g. 0.95
+        max_factor: float,  # e.g. 0.01
+        min_abs: float,  # e.g. 0.2
+        max_abs: float,  # e.g. 100.0
     ) -> Tensor:
         if x.requires_grad:
             if channel_dim < 0:
                 channel_dim += x.ndim
             sum_dims = [d for d in range(x.ndim) if d != channel_dim]
             xgt0 = x > 0
-            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
-            factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
-                       if min_positive != 0.0 else 0.0)
-            factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
-                       if max_positive != 1.0 else 0.0)
+            proportion_positive = torch.mean(
+                xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            )
+            factor1 = (
+                (min_positive - proportion_positive).relu()
+                * (max_factor / min_positive)
+                if min_positive != 0.0
+                else 0.0
+            )
+            factor2 = (
+                (proportion_positive - max_positive).relu()
+                * (max_factor / (max_positive - 1.0))
+                if max_positive != 1.0
+                else 0.0
+            )
             factor = factor1 + factor2
             if isinstance(factor, float):
                 factor = torch.zeros_like(proportion_positive)
 
             mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
-            below_threshold = (mean_abs < min_abs)
-            above_threshold = (mean_abs > max_abs)
+            below_threshold = mean_abs < min_abs
+            above_threshold = mean_abs > max_abs
 
-            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
+            ctx.save_for_backward(
+                factor, xgt0, below_threshold, above_threshold
+            )
             ctx.max_factor = max_factor
             ctx.sum_dims = sum_dims
         return x
 
     @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
+    def backward(
+        ctx, x_grad: Tensor
+    ) -> Tuple[Tensor, None, None, None, None, None, None]:
         factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
         dtype = x_grad.dtype
-        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
-                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
+        scale_factor = (
+            (below_threshold.to(dtype) - above_threshold.to(dtype))
+            * (xgt0.to(dtype) - 0.5)
+            * (ctx.max_factor * 2.0)
+        )
+
         neg_delta_grad = x_grad.abs() * (factor + scale_factor)
         return x_grad - neg_delta_grad, None, None, None, None, None, None
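Note: ActivationBalancerFunction, reformatted above, is an identity in the forward pass; all of its effect is in backward(), where it nudges gradients so that the proportion of positive activations and their mean absolute value drift back into the configured range. A quick check that the forward pass really is a no-op; it assumes the scaling.py shown here is importable from the recipe directory:

import torch
from scaling import ActivationBalancer  # the nn.Module wrapper defined later in this file

balancer = ActivationBalancer(
    channel_dim=-1, min_positive=0.05, max_positive=0.95, max_factor=0.01
)
x = torch.randn(4, 16, requires_grad=True)
y = balancer(x)
print(torch.equal(y, x))   # True: values pass through unchanged
y.sum().backward()         # backward() may adjust gradients when the statistics
print(x.grad.shape)        # of x drift outside the configured range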
@@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module):
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
     """
-    def __init__(self,
-                 num_channels: int,
-                 channel_dim: int = -1,  # CAUTION: see documentation.
-                 eps: float = 0.25,
-                 learn_eps: bool = True) -> None:
+
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+        eps: float = 0.25,
+        learn_eps: bool = True,
+    ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         if learn_eps:
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
-            self.register_buffer('eps', torch.tensor(eps).log().detach())
+            self.register_buffer("eps", torch.tensor(eps).log().detach())
 
-
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
-                  self.eps.exp()) ** -0.5
+        scales = (
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+            + self.eps.exp()
+        ) ** -0.5
         return x * scales
 
-
-
 
 class ScaledLinear(nn.Linear):
     """
     A modified version of nn.Linear where the parameters are scaled before
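Note: BasicNorm, whose forward is reformatted above, is a simplified LayerNorm replacement: it only divides by a root-mean-square-like statistic (with an epsilon stored in log space, optionally learned) and has no per-channel scale or bias. A stand-alone version of the same arithmetic for comparison:

import torch

def basic_norm(x: torch.Tensor, log_eps: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
    # Same computation as BasicNorm.forward above; eps is kept as log(eps).
    scales = (
        torch.mean(x ** 2, dim=channel_dim, keepdim=True) + log_eps.exp()
    ) ** -0.5
    return x * scales

x = torch.randn(10, 512)
log_eps = torch.tensor(0.25).log()
y = basic_norm(x, log_eps)
print(y.shape, y.pow(2).mean(dim=-1)[:3])  # per-row mean square is pushed toward ~1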
@@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear):
          Alternatively you can set it to more than 1 if you want it to
          initially train faster. Must be greater than 0.
     """
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
         super(ScaledLinear, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
-            self.register_parameter('bias_scale', None)
+            self.register_parameter("bias_scale", None)
 
-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in nn.Linear
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
@@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear):
         return self.weight * self.weight_scale.exp()
 
     def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
-        return torch.nn.functional.linear(input, self.get_weight(),
-                                          self.get_bias())
+        return torch.nn.functional.linear(
+            input, self.get_weight(), self.get_bias()
+        )
 
 
 class ScaledConv1d(nn.Conv1d):
     # See docs for ScaledLinear
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
-            self.register_parameter('bias_scale', None)
-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
+            self.register_parameter("bias_scale", None)
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in base class
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
@@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d):
             with torch.no_grad():
                 self.weight_scale += torch.tensor(scale / std).log()
 
-
     def get_weight(self):
         return self.weight * self.weight_scale.exp()
 
     def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
         F = torch.nn.functional
-        if self.padding_mode != 'zeros':
-            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            self.get_weight(), self.get_bias(), self.stride,
-                            _single(0), self.dilation, self.groups)
-        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
-                        self.padding, self.dilation, self.groups)
+        if self.padding_mode != "zeros":
+            return F.conv1d(
+                F.pad(
+                    input,
+                    self._reversed_padding_repeated_twice,
+                    mode=self.padding_mode,
+                ),
+                self.get_weight(),
+                self.get_bias(),
+                self.stride,
+                _single(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv1d(
+            input,
+            self.get_weight(),
+            self.get_bias(),
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
 
 
 class ScaledConv2d(nn.Conv2d):
     # See docs for ScaledLinear
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
-            self.register_parameter('bias_scale', None)
-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
+            self.register_parameter("bias_scale", None)
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in base class
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
@@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d):
             with torch.no_grad():
                 self.weight_scale += torch.tensor(scale / std).log()
 
-
     def get_weight(self):
         return self.weight * self.weight_scale.exp()
 
     def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()
 
     def _conv_forward(self, input, weight):
         F = torch.nn.functional
-        if self.padding_mode != 'zeros':
-            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.get_bias(), self.stride,
-                            _pair(0), self.dilation, self.groups)
-        return F.conv2d(input, weight, self.get_bias(), self.stride,
-                        self.padding, self.dilation, self.groups)
+        if self.padding_mode != "zeros":
+            return F.conv2d(
+                F.pad(
+                    input,
+                    self._reversed_padding_repeated_twice,
+                    mode=self.padding_mode,
+                ),
+                weight,
+                self.get_bias(),
+                self.stride,
+                _pair(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv2d(
+            input,
+            weight,
+            self.get_bias(),
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
 
     def forward(self, input: Tensor) -> Tensor:
         return self._conv_forward(input, self.get_weight())
 
 
-
-
 class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
@@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module):
          we allow, before we start to modify the derivatives to prevent
          this.
     """
-    def __init__(self, channel_dim: int,
-                 min_positive: float = 0.05,
-                 max_positive: float = 0.95,
-                 max_factor: float = 0.01,
-                 min_abs: float = 0.2,
-                 max_abs: float = 100.0):
+
+    def __init__(
+        self,
+        channel_dim: int,
+        min_positive: float = 0.05,
+        max_positive: float = 0.95,
+        max_factor: float = 0.01,
+        min_abs: float = 0.2,
+        max_abs: float = 100.0,
+    ):
         super(ActivationBalancer, self).__init__()
         self.channel_dim = channel_dim
         self.min_positive = min_positive
@@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module):
         self.max_abs = max_abs
 
     def forward(self, x: Tensor) -> Tensor:
-        return ActivationBalancerFunction.apply(x, self.channel_dim,
-                                                self.min_positive, self.max_positive,
-                                                self.max_factor, self.min_abs,
-                                                self.max_abs)
+        return ActivationBalancerFunction.apply(
+            x,
+            self.channel_dim,
+            self.min_positive,
+            self.max_positive,
+            self.max_factor,
+            self.min_abs,
+            self.max_abs,
+        )
 
 
 class DoubleSwishFunction(torch.autograd.Function):
@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function):
|
|||||||
= double_swish(x) * (1-s(x)) + s(x)
|
= double_swish(x) * (1-s(x)) + s(x)
|
||||||
... so we just need to remember s(x) but not x itself.
|
... so we just need to remember s(x) but not x itself.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor) -> Tensor:
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
@@ -349,18 +436,17 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
         s, y = ctx.saved_tensors
-        return (y * (1-s) + s) * y_grad
+        return (y * (1 - s) + s) * y_grad


 class DoubleSwish(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
         return DoubleSwishFunction.apply(x)


 class ScaledEmbedding(nn.Module):
     r"""This is a modified version of nn.Embedding that introduces a learnable scale
     on the parameters.  Note: due to how we initialize it, it's best used with
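
The identity quoted in the docstring above, double_swish'(x) = double_swish(x) * (1 - s(x)) + s(x) with s(x) = sigmoid(x - 1), is what lets backward() save only s(x) instead of x. A quick numerical check (illustration only, not code from the commit):

import torch

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)

s = torch.sigmoid(x - 1.0)   # s(x)
y = x * s                    # double_swish(x) = x * sigmoid(x - 1)

# Derivative computed by autograd ...
(g_autograd,) = torch.autograd.grad(y.sum(), x)

# ... matches the closed form used in backward(), which needs only s and y:
g_closed = y * (1 - s) + s
assert torch.allclose(g_autograd, g_closed)
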
@@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module):
             [-0.1655,  0.9897,  0.0635]]])

     """
-    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
-                     'scale_grad_by_freq', 'sparse']
+    __constants__ = [
+        "num_embeddings",
+        "embedding_dim",
+        "padding_idx",
+        "scale_grad_by_freq",
+        "sparse",
+    ]

     num_embeddings: int
     embedding_dim: int
@@ -453,33 +544,41 @@ class ScaledEmbedding(nn.Module):
     weight: Tensor
     sparse: bool

-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 initial_speed: float = 1.0) -> None:
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        initial_speed: float = 1.0,
+    ) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         if padding_idx is not None:
             if padding_idx > 0:
-                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert (
+                    padding_idx < self.num_embeddings
+                ), "Padding_idx must be within num_embeddings"
             elif padding_idx < 0:
-                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert (
+                    padding_idx >= -self.num_embeddings
+                ), "Padding_idx must be within num_embeddings"
                 padding_idx = self.num_embeddings + padding_idx
         self.padding_idx = padding_idx
         self.scale_grad_by_freq = scale_grad_by_freq

         self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
         self.sparse = sparse

         self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
         self.reset_parameters(initial_speed)

-
     def reset_parameters(self, initial_speed: float = 1.0) -> None:
         std = 0.1 / initial_speed
         nn.init.normal_(self.weight, std=std)
-        nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
+        nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())

         if self.padding_idx is not None:
             with torch.no_grad():
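
For reference (not part of the commit): reset_parameters() above draws the raw weights with std = 0.1 / initial_speed and stores scale = log(1 / std), so the product weight * scale.exp() used in forward() starts out with roughly unit standard deviation while the raw parameter stays small. A small sketch of that arithmetic, with example sizes that are assumptions rather than values from the recipe:

import torch

initial_speed = 1.0                    # assumed value (the default)
std = 0.1 / initial_speed              # std used for the raw weight
scale = torch.tensor(1.0 / std).log()  # stored log-scale parameter

weight = torch.randn(5000, 512) * std  # stand-in for nn.init.normal_(self.weight, std=std)
effective = weight * scale.exp()       # what forward() actually multiplies into the lookup

print(weight.std())     # ~0.1: small raw parameter, so relative updates are comparatively large
print(effective.std())  # ~1.0: the magnitude the rest of the network sees
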
@@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module):
         F = torch.nn.functional
         scale = self.scale.exp()
         if input.numel() < self.num_embeddings:
-            return F.embedding(
-                input, self.weight, self.padding_idx,
-                None, 2.0,  # None, 2.0 relate to normalization
-                self.scale_grad_by_freq, self.sparse) * scale
+            return (
+                F.embedding(
+                    input,
+                    self.weight,
+                    self.padding_idx,
+                    None,
+                    2.0,  # None, 2.0 relate to normalization
+                    self.scale_grad_by_freq,
+                    self.sparse,
+                )
+                * scale
+            )
         else:
             return F.embedding(
-                input, self.weight * scale, self.padding_idx,
-                None, 2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq, self.sparse)
+                input,
+                self.weight * scale,
+                self.padding_idx,
+                None,
+                2.0,  # None, 2.0 relates to normalization
+                self.scale_grad_by_freq,
+                self.sparse,
+            )

     def extra_repr(self) -> str:
-        s = '{num_embeddings}, {embedding_dim}, scale={scale}'
+        s = "{num_embeddings}, {embedding_dim}, scale={scale}"
         if self.padding_idx is not None:
-            s += ', padding_idx={padding_idx}'
+            s += ", padding_idx={padding_idx}"
         if self.scale_grad_by_freq is not False:
-            s += ', scale_grad_by_freq={scale_grad_by_freq}'
+            s += ", scale_grad_by_freq={scale_grad_by_freq}"
         if self.sparse is not False:
-            s += ', sparse=True'
+            s += ", sparse=True"
         return s.format(**self.__dict__)


 def _test_activation_balancer_sign():
-    channel_dim = 0
     probs = torch.arange(0, 1, 0.01)
     N = 1000
     x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
     x = x.detach()
     x.requires_grad = True
-    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
-                           max_factor=0.2, min_abs=0.0)
+    m = ActivationBalancer(
+        channel_dim=0,
+        min_positive=0.05,
+        max_positive=0.95,
+        max_factor=0.2,
+        min_abs=0.0,
+    )

     y_grad = torch.sign(torch.randn(probs.numel(), N))

@@ -528,17 +644,23 @@ def _test_activation_balancer_sign():
     print("_test_activation_balancer_sign: y grad = ", y_grad)
     print("_test_activation_balancer_sign: x grad = ", x.grad)


 def _test_activation_balancer_magnitude():
-    channel_dim = 0
     magnitudes = torch.arange(0, 1, 0.01)
     N = 1000
-    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
+        -1
+    )
     x = x.detach()
     x.requires_grad = True
-    m = ActivationBalancer(channel_dim=0,
-                           min_positive=0.0, max_positive=1.0,
-                           max_factor=0.2,
-                           min_abs=0.2, max_abs=0.8)
+    m = ActivationBalancer(
+        channel_dim=0,
+        min_positive=0.0,
+        max_positive=1.0,
+        max_factor=0.2,
+        min_abs=0.2,
+        max_abs=0.8,
+    )

     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

@@ -558,8 +680,8 @@ def _test_basic_norm():
     y = m(x)

     assert y.shape == x.shape
-    x_rms = (x**2).mean().sqrt()
-    y_rms = (y**2).mean().sqrt()
+    x_rms = (x ** 2).mean().sqrt()
+    y_rms = (y ** 2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
     assert y_rms < x_rms
@@ -573,7 +695,7 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


-if __name__ == '__main__':
+if __name__ == "__main__":
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
@@ -45,16 +45,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 import argparse
 import logging
-import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union

 import k2
+import optim
 import sentencepiece as spm
 import torch
-import optim  # from .
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -65,27 +64,24 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eve, Eden
+from optim import Eden, Eve
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

+from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall import diagnostics
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]


 def get_parser():
     parser = argparse.ArgumentParser(
@@ -168,7 +164,7 @@ def get_parser():
         type=float,
         default=5000,
         help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this."""
+        We suggest not to change this.""",
     )

     parser.add_argument(
@@ -176,7 +172,7 @@ def get_parser():
         type=float,
         default=6,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
-        """
+        """,
     )

     parser.add_argument(
@@ -335,7 +331,7 @@ def get_params() -> AttributeDict:
             # parameters for joiner
             "joiner_dim": 512,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -510,7 +506,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0
+    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -557,18 +553,24 @@ def compute_loss(
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (0.0 if warmup < 1.0 else
-                             (0.1 if warmup > 1.0 and warmup < 2.0 else
-                              1.0))
-        loss = (params.simple_loss_scale * simple_loss +
-                pruned_loss_scale * pruned_loss)
+        pruned_loss_scale = (
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
+        )

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
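
The re-wrapped expression above is behaviour-preserving. Spelled out as a plain function (a sketch, not code from the recipe), the scale applied to pruned_loss as warmup = batch_idx_train / model_warm_step grows is:

def pruned_loss_schedule(warmup: float) -> float:
    # Mirrors the ternary in compute_loss() above.
    if warmup < 1.0:
        return 0.0  # pruned loss switched off while the model warms up
    elif warmup > 1.0 and warmup < 2.0:
        return 0.1  # then blended in with a small weight
    else:
        return 1.0  # full weight from 2 * model_warm_step onward (and at exactly warmup == 1.0)


# loss = params.simple_loss_scale * simple_loss + pruned_loss_schedule(warmup) * pruned_loss
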
@@ -675,7 +677,7 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step)
+                warmup=(params.batch_idx_train / params.model_warm_step),
             )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -691,8 +693,10 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
             return

-        if (params.batch_idx_train > 0
-                and params.batch_idx_train % params.save_every_n == 0):
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
             params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
@@ -723,7 +727,7 @@ def train_one_epoch(

             if tb_writer is not None:
                 tb_writer.add_scalar(
                     "train/learning_rate", cur_lr, params.batch_idx_train
                 )

                 loss_info.write_summary(
@@ -813,18 +817,19 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank])
     model.device = device

-    optimizer = Eve(
-        model.parameters(),
-        lr=params.initial_lr)
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

-
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
         optimizer.load_state_dict(checkpoints["optimizer"])

-    if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None:
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])

@@ -834,7 +839,6 @@ def run(rank, world_size, args):
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

-
     librispeech = LibriSpeechAsrDataModule(args)

     train_cuts = librispeech.train_clean_100_cuts()
@@ -889,7 +893,6 @@ def run(rank, world_size, args):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)

-        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

@@ -956,7 +959,7 @@ def scan_pessimistic_batches_for_oom(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup = 0.0
+                warmup=0.0,
             )
             loss.backward()
             optimizer.step()
@@ -486,7 +486,9 @@ def modified_beam_search(
     for i in range(batch_size):
         topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

-        topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
+        topk_hyp_indexes = torch.div(
+            topk_indexes, vocab_size, rounding_mode="trunc"
+        )
         topk_hyp_indexes = topk_hyp_indexes.tolist()
         topk_token_indexes = (topk_indexes % vocab_size).tolist()

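
For context (illustration, not part of the commit): ragged_log_probs[i] holds the per-utterance (hypothesis, token) scores flattened to length num_hyps * vocab_size, so each top-k index is split back into a hypothesis index and a token id with integer division and modulo, exactly as the reformatted call does. A stand-alone sketch with assumed sizes:

import torch

vocab_size = 500   # assumed for the example
num_hyps = 4
beam = 4

flat_log_probs = torch.randn(num_hyps * vocab_size)  # stand-in for ragged_log_probs[i]
topk_log_probs, topk_indexes = flat_log_probs.topk(beam)

# Same unpacking as in modified_beam_search():
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
topk_token_indexes = topk_indexes % vocab_size

for hyp, tok in zip(topk_hyp_indexes.tolist(), topk_token_indexes.tolist()):
    print(f"extend hypothesis {hyp} with token id {tok}")
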
@@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer

 # use duck typing for LRScheduler since we have different possibilities, see
 # our class LRScheduler.
 LRSchedulerType = object


 def save_checkpoint(
     filename: Path,
     model: Union[nn.Module, DDP],
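
Side note (not part of the commit): checkpoint.py types schedulers as a plain object on purpose, so either torch.optim.lr_scheduler._LRScheduler or the recipe's own optim.LRScheduler can be passed wherever a scheduler is annotated, without importing either class here. A minimal sketch of the idea, using a hypothetical helper:

from typing import Optional

# Duck typing: any object that provides state_dict() / load_state_dict() will do,
# so no scheduler class needs to be imported at annotation time.
LRSchedulerType = object


def save_scheduler_state(
    checkpoint: dict, scheduler: Optional[LRSchedulerType] = None
) -> None:
    # Hypothetical helper for illustration; not a function from icefall.
    if scheduler is not None:
        checkpoint["scheduler"] = scheduler.state_dict()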