Code style check for librispeech pruned transducer stateless2 (#308)

Mingshuang Luo 2022-04-11 22:15:18 +08:00 committed by GitHub
parent 8cb727e24a
commit 93c60a9d30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 484 additions and 296 deletions

View File

@@ -7,6 +7,8 @@ per-file-ignores =
    egs/librispeech/ASR/*/conformer.py: E501,
    egs/aishell/ASR/*/conformer.py: E501,
    egs/tedlium3/ASR/*/conformer.py: E501,
+   egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
    # invalid escape sequence (cause by tex formular), W605
    icefall/utils.py: E501, W605

View File

@@ -93,7 +93,9 @@ def fast_beam_search(
        )
        # fmt: on
        logits = model.joiner(
-            current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False
+            current_encoder_out.unsqueeze(2),
+            decoder_out.unsqueeze(1),
+            project_input=False,
        )
        logits = logits.squeeze(1).squeeze(1)
        log_probs = logits.log_softmax(dim=-1)

@@ -140,7 +142,6 @@ def greedy_search(
    encoder_out = model.joiner.encoder_proj(encoder_out)

    T = encoder_out.size(1)
    t = 0
    hyp = [blank_id] * context_size

@@ -163,9 +164,9 @@ def greedy_search(
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        # fmt: on
-        logits = model.joiner(current_encoder_out,
-                              decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
        # logits is (1, 1, 1, vocab_size)
        y = logits.argmax().item()

@@ -228,8 +229,9 @@ def greedy_search_batch(
    for t in range(T):
        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
        # logits'shape (batch_size, 1, 1, vocab_size)
        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)

@@ -466,7 +468,6 @@ def modified_beam_search(
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(

@@ -720,7 +721,7 @@ def beam_search(
            logits = model.joiner(
                current_encoder_out,
                decoder_out.unsqueeze(1),
-                project_input=False
+                project_input=False,
            )
            # TODO(fangjun): Scale the blank posterior

View File

@@ -16,13 +16,20 @@
# limitations under the License.

import copy
-from encoder_interface import EncoderInterface
import math
import warnings
-from typing import Optional, Tuple, Sequence
-from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from typing import Optional, Tuple

import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
from torch import Tensor, nn

from icefall.utils import make_pad_mask

@@ -42,6 +49,7 @@ class Conformer(EncoderInterface):
        cnn_module_kernel (int): Kernel size of convolution module
        vgg_frontend (bool): whether to use vgg frontend.
    """

    def __init__(
        self,
        num_features: int,

@@ -80,9 +88,8 @@ class Conformer(EncoderInterface):
        )
        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:

@@ -112,8 +119,9 @@ class Conformer(EncoderInterface):
        assert x.size(0) == lengths.max().item()
        mask = make_pad_mask(lengths)

-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup=warmup)  # (T, N, C)
+        x = self.encoder(
+            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+        )  # (T, N, C)
        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

@@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module):
        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

        self.norm_final = BasicNorm(d_model)

        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(channel_dim=-1,
-                                           min_positive=0.45,
-                                           max_positive=0.55,
-                                           max_abs=6.0)
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: Tensor,

@@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module):
        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
        # completely bypass it.
        if self.training:
-            alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else 0.1
+            )
        else:
            alpha = 1.0

        # macaron style feed forward module
        src = src + self.dropout(self.feed_forward_macaron(src))

        # multi-headed self-attention module
        src_att = self.self_attn(
            src,

@@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module):
        src = self.norm_final(self.balancer(src))

        if alpha != 1.0:
-            src = alpha * src + (1-alpha) * src_orig
+            src = alpha * src + (1 - alpha) * src_orig

        return src
@@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module):
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
-        warmup: float = 1.0
+        warmup: float = 1.0,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

@@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module):
        """
        output = src

-        num_layers = len(self.layers)
        for i, mod in enumerate(self.layers):
            output = mod(
                output,

@@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module):
        ), "embed_dim must be divisible by num_heads"

        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)

@@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module):
    if torch.equal(query, key) and torch.equal(key, value):
        # self-attention
-        q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+        q, k, v = nn.functional.linear(
+            query, in_proj_weight, in_proj_bias
+        ).chunk(3, dim=-1)

    elif torch.equal(key, value):
        # encoder-decoder attention

@@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
            _b = _b[_start:_end]
        q = nn.functional.linear(query, _w, _b)

        # This is inline in_proj function with in_proj_weight and in_proj_bias
        _b = in_proj_bias
        _start = embed_dim

@@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
            _b = _b[_start:]
        v = nn.functional.linear(value, _w, _b)

    if attn_mask is not None:
        assert (
            attn_mask.dtype == torch.float32

@@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module):
        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
        # it will be in a better position to start learning something, i.e. to latch onto
        # the correct range.
-        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )

        self.depthwise_conv = ScaledConv1d(
            channels,

@@ -878,9 +885,9 @@ class ConvolutionModule(nn.Module):
            bias=bias,
        )

-        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )

        self.activation = DoubleSwish()

@@ -891,7 +898,7 @@ class ConvolutionModule(nn.Module):
            stride=1,
            padding=0,
            bias=bias,
-            initial_scale=0.25
+            initial_scale=0.25,
        )

    def forward(self, x: Tensor) -> Tensor:
@@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module):
        return x.permute(2, 0, 1)


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

@@ -936,11 +942,14 @@ class Conv2dSubsampling(nn.Module):
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

-    def __init__(self, in_channels: int,
-                 out_channels: int,
-                 layer1_channels: int = 8,
-                 layer2_channels: int = 32,
-                 layer3_channels: int = 128) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+    ) -> None:
        """
        Args:
          in_channels:

@@ -958,34 +967,41 @@ class Conv2dSubsampling(nn.Module):
        self.conv = nn.Sequential(
            ScaledConv2d(
-                in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, padding=1,
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
-                in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2,
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
-                in_channels=layer2_channels, out_channels=layer3_channels,
-                kernel_size=3, stride=2,
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
        )
-        self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        # constrain median of output to be close to zero.
-        self.out_balancer = ActivationBalancer(channel_dim=-1,
-                                               min_positive=0.45,
-                                               max_positive=0.55)
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

@@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module):
        return x


-if __name__ == '__main__':
+if __name__ == "__main__":
    feature_dim = 50
    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
    batch_size = 5
    seq_len = 20
    # Just make sure the forward pass runs.
-    f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64),
-          warmup=0.5)
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
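The warmup block reformatted above keeps the layer-dropout behaviour unchanged: during training a batch uses the full layer output scaled by warmup_scale with probability 1 - layer_dropout, and otherwise mostly bypasses the layer. A minimal standalone sketch of that blending, using plain tensors in place of the real ConformerEncoderLayer (illustration only, not code from this commit):

import torch


def blend_with_warmup(
    src: torch.Tensor,
    src_orig: torch.Tensor,
    warmup_scale: float,
    layer_dropout: float,
    training: bool,
) -> torch.Tensor:
    # Mirrors the reformatted logic in ConformerEncoderLayer.forward:
    # keep the full output with probability (1 - layer_dropout), otherwise
    # fall back to alpha = 0.1; at inference time alpha is always 1.0.
    if training:
        alpha = (
            warmup_scale
            if torch.rand(()).item() <= (1.0 - layer_dropout)
            else 0.1
        )
    else:
        alpha = 1.0
    if alpha != 1.0:
        return alpha * src + (1 - alpha) * src_orig
    return src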

View File

@@ -17,9 +17,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional
-from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
+from scaling import ScaledConv1d, ScaledEmbedding


class Decoder(nn.Module):

View File

@@ -16,15 +16,17 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
from scaling import ScaledLinear


class Joiner(nn.Module):
-    def __init__(self,
-                 encoder_dim: int,
-                 decoder_dim: int,
-                 joiner_dim: int,
-                 vocab_size: int):
+    def __init__(
+        self,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
        super().__init__()

        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)

@@ -32,8 +34,10 @@ class Joiner(nn.Module):
        self.output_linear = ScaledLinear(joiner_dim, vocab_size)

    def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
-        project_input: bool = True
+        self,
+        encoder_out: torch.Tensor,
+        decoder_out: torch.Tensor,
+        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:

@@ -52,7 +56,9 @@ class Joiner(nn.Module):
        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]

        if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
+                decoder_out
+            )
        else:
            logit = encoder_out + decoder_out
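The project_input flag reformatted above is what lets the decoding loops in beam_search.py apply encoder_proj and decoder_proj once, outside the per-frame loop, and then ask the joiner to only add the already-projected tensors. A self-contained sketch of that equivalence, with plain nn.Linear standing in for ScaledLinear (illustration only, not code from this commit):

import torch
import torch.nn as nn

joiner_dim = 512
encoder_proj = nn.Linear(512, joiner_dim)  # stand-in for ScaledLinear
decoder_proj = nn.Linear(512, joiner_dim)

encoder_out = torch.randn(4, 1, 1, 512)  # (batch, 1, 1, encoder_dim)
decoder_out = torch.randn(4, 1, 1, 512)  # (batch, 1, 1, decoder_dim)

# project_input=True: the joiner projects both inputs itself.
logit_inside = encoder_proj(encoder_out) + decoder_proj(decoder_out)

# project_input=False: the caller projects once up front and the joiner
# simply adds the projected tensors, as the decoding loops above do.
proj_enc = encoder_proj(encoder_out)
proj_dec = decoder_proj(decoder_out)
logit_outside = proj_enc + proj_dec

assert torch.allclose(logit_inside, logit_outside)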

View File

@@ -37,7 +37,7 @@ class Transducer(nn.Module):
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
-        vocab_size: int
+        vocab_size: int,
    ):
        """
        Args:

@@ -48,11 +48,11 @@ class Transducer(nn.Module):
            `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
-            is (N, U) and its output shape is (N, U, decoder_dim). It should contain
-            one attribute: `blank_id`.
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
          joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
-            output shape is (N, T, U, vocab_size). Note that its output contains
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
        """
        super().__init__()

@@ -63,8 +63,9 @@ class Transducer(nn.Module):
        self.decoder = decoder
        self.joiner = joiner

-        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
-                                           initial_speed=0.5)
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_speed=0.5
+        )
        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

    def forward(

@@ -141,8 +142,8 @@ class Transducer(nn.Module):
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

-        lm=self.simple_lm_proj(decoder_out)
-        am=self.simple_am_proj(encoder_out)
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(

@@ -170,15 +171,14 @@ class Transducer(nn.Module):
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
-            ranges=ranges
+            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned,
-                             project_input=False)
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(

View File

@@ -15,11 +15,9 @@
# limitations under the License.

-import random
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union

import torch
-from torch import Tensor
from torch.optim import Optimizer

@@ -59,24 +57,41 @@ class Eve(Optimizer):
    https://openreview.net/forum?id=ryQu7f-RZ
    """

-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
-                 weight_decay=1e-3, target_rms=0.1):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.98),
+        eps=1e-8,
+        weight_decay=1e-3,
+        target_rms=0.1,
+    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+            raise ValueError(
+                "Invalid beta parameter at index 0: {}".format(betas[0])
+            )
        if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+            raise ValueError(
+                "Invalid beta parameter at index 1: {}".format(betas[1])
+            )
        if not 0 <= weight_decay <= 0.1:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
        if not 0 < target_rms <= 10.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay,
-                        target_rms=target_rms)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            target_rms=target_rms,
+        )
        super(Eve, self).__init__(params, defaults)

    def __setstate__(self, state):
@@ -96,83 +111,98 @@ class Eve(Optimizer):
                loss = closure()

        for group in self.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
-                    raise RuntimeError('AdamW does not support sparse gradients')
+                    raise RuntimeError(
+                        "AdamW does not support sparse gradients"
+                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
-                    state['step'] = 0
+                    state["step"] = 0
                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )
                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p, memory_format=torch.preserve_format
+                    )

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

-                beta1, beta2 = group['betas']
+                beta1, beta2 = group["betas"]

-                state['step'] += 1
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
+                state["step"] += 1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
+                    group["eps"]
+                )

-                step_size = group['lr'] / bias_correction1
-                target_rms = group['target_rms']
-                weight_decay = group['weight_decay']
-                delta = exp_avg / denom
+                step_size = group["lr"] / bias_correction1
+                target_rms = group["target_rms"]
+                weight_decay = group["weight_decay"]

                if p.numel() > 1:
                    # avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).
-                    is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
+                    is_above_target_rms = p.norm() > (
+                        target_rms * (p.numel() ** 0.5)
+                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
-            raise TypeError('{} is not an Optimizer'.format(
-                type(optimizer).__name__))
+            raise TypeError(
+                "{} is not an Optimizer".format(type(optimizer).__name__)
+            )
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
-            group.setdefault('initial_lr', group['lr'])
+            group.setdefault("initial_lr", group["lr"])

-        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+        self.base_lrs = [
+            group["initial_lr"] for group in optimizer.param_groups
+        ]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
-        return {'base_lrs': self.base_lrs,
-                'epoch': self.epoch,
-                'batch': self.batch}
+        return {
+            "base_lrs": self.base_lrs,
+            "epoch": self.epoch,
+            "batch": self.batch,
+        }

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.
@@ -184,8 +214,7 @@ class LRScheduler(object):
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
-        """ Return last computed learning rate by current scheduler.  Will be a list of float.
-        """
+        """Return last computed learning rate by current scheduler.  Will be a list of float."""
        return self._last_lr

    def get_lr(self):

@@ -194,7 +223,6 @@ class LRScheduler(object):
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it.  If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over

@@ -217,24 +245,23 @@ class LRScheduler(object):
        self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
-            param_group['lr'] = lr
+            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
-        """Display the current learning rate.
-        """
+        """Display the current learning rate."""
        if is_verbose:
-            print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
-                  f' of group {group} to {lr:.4e}.')
+            print(
+                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
+                f" of group {group} to {lr:.4e}."
+            )


class Eden(LRScheduler):
@@ -254,18 +281,27 @@ class Eden(LRScheduler):
       20 to 40 epochs, but may need smaller number if dataset is huge
       and you will do few epochs.
    """

-    def __init__(self, optimizer: Optimizer,
-                 lr_batches: Union[int, float],
-                 lr_epochs: Union[int, float],
-                 verbose: bool = False):
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        lr_batches: Union[int, float],
+        lr_epochs: Union[int, float],
+        verbose: bool = False,
+    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs

    def get_lr(self):
-        factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
-                  (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
-        return [ x * factor for x in self.base_lrs ]
+        factor = (
+            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
+        ) ** -0.25 * (
+            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
+            ** -0.25
+        )
+        return [x * factor for x in self.base_lrs]


def _test_eden():
    m = torch.nn.Linear(100, 100)

@@ -290,5 +326,6 @@ def _test_eden():
    print("last lr = ", scheduler.get_last_lr())
    print("state dict = ", scheduler.state_dict())


-if __name__ == '__main__':
+if __name__ == "__main__":
    _test_eden()
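The reformatted Eden.get_lr above computes the same schedule as before; written out as a standalone helper, with lr_batches=5000 and lr_epochs=6 taken from the train.py defaults later in this diff (illustration only, not code from this commit):

def eden_factor(
    batch: int, epoch: int, lr_batches: float = 5000, lr_epochs: float = 6
) -> float:
    # Two decaying terms, one driven by the batch count and one by the epoch count.
    return (
        (batch ** 2 + lr_batches ** 2) / lr_batches ** 2
    ) ** -0.25 * (
        ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    )


print(eden_factor(batch=0, epoch=0))      # 1.0 at the start of training
print(eden_factor(batch=5000, epoch=6))   # about 0.71 once batch == lr_batches and epoch == lr_epochs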

View File

@@ -15,54 +15,86 @@
# limitations under the License.

+import collections
+from itertools import repeat
+from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
-from typing import Tuple, Optional

+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.Iterable):
+            return x
+        return tuple(repeat(x, n))
+    return parse

+_single = _ntuple(1)
+_pair = _ntuple(2)


class ActivationBalancerFunction(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, x: Tensor,
-                channel_dim: int,
-                min_positive: float,  # e.g. 0.05
-                max_positive: float,  # e.g. 0.95
-                max_factor: float,  # e.g. 0.01
-                min_abs: float,  # e.g. 0.2
-                max_abs: float,  # e.g. 100.0
+    def forward(
+        ctx,
+        x: Tensor,
+        channel_dim: int,
+        min_positive: float,  # e.g. 0.05
+        max_positive: float,  # e.g. 0.95
+        max_factor: float,  # e.g. 0.01
+        min_abs: float,  # e.g. 0.2
+        max_abs: float,  # e.g. 100.0
    ) -> Tensor:
        if x.requires_grad:
            if channel_dim < 0:
                channel_dim += x.ndim
            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
            xgt0 = x > 0
-            proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
-            factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
-                       if min_positive != 0.0 else 0.0)
-            factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
-                       if max_positive != 1.0 else 0.0)
+            proportion_positive = torch.mean(
+                xgt0.to(x.dtype), dim=sum_dims, keepdim=True
+            )
+            factor1 = (
+                (min_positive - proportion_positive).relu()
+                * (max_factor / min_positive)
+                if min_positive != 0.0
+                else 0.0
+            )
+            factor2 = (
+                (proportion_positive - max_positive).relu()
+                * (max_factor / (max_positive - 1.0))
+                if max_positive != 1.0
+                else 0.0
+            )
            factor = factor1 + factor2
            if isinstance(factor, float):
                factor = torch.zeros_like(proportion_positive)

            mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
-            below_threshold = (mean_abs < min_abs)
-            above_threshold = (mean_abs > max_abs)
+            below_threshold = mean_abs < min_abs
+            above_threshold = mean_abs > max_abs

-            ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
+            ctx.save_for_backward(
+                factor, xgt0, below_threshold, above_threshold
+            )
            ctx.max_factor = max_factor
            ctx.sum_dims = sum_dims
        return x

    @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
+    def backward(
+        ctx, x_grad: Tensor
+    ) -> Tuple[Tensor, None, None, None, None, None, None]:
        factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
        dtype = x_grad.dtype
-        scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
-                        (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
+        scale_factor = (
+            (below_threshold.to(dtype) - above_threshold.to(dtype))
+            * (xgt0.to(dtype) - 0.5)
+            * (ctx.max_factor * 2.0)
+        )
        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
        return x_grad - neg_delta_grad, None, None, None, None, None, None
@@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module):
       learn_eps: if true, we learn epsilon; if false, we keep it
         at the initial value.
    """

-    def __init__(self,
-                 num_channels: int,
-                 channel_dim: int = -1,  # CAUTION: see documentation.
-                 eps: float = 0.25,
-                 learn_eps: bool = True) -> None:
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+        eps: float = 0.25,
+        learn_eps: bool = True,
+    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
-            self.register_buffer('eps', torch.tensor(eps).log().detach())
+            self.register_buffer("eps", torch.tensor(eps).log().detach())

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
-        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
-                  self.eps.exp()) ** -0.5
+        scales = (
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+            + self.eps.exp()
+        ) ** -0.5
        return x * scales


class ScaledLinear(nn.Linear):
    """
    A modified version of nn.Linear where the parameters are scaled before

@@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear):
          Alternatively you can set it to more than 1 if you want it to
          initially train faster.  Must be greater than 0.
    """

-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
        super(ScaledLinear, self).__init__(*args, **kwargs)
        initial_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
-            self.register_parameter('bias_scale', None)
+            self.register_parameter("bias_scale", None)

-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in nn.Linear

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
@@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear):
        return self.weight * self.weight_scale.exp()

    def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()

    def forward(self, input: Tensor) -> Tensor:
-        return torch.nn.functional.linear(input, self.get_weight(),
-                                          self.get_bias())
+        return torch.nn.functional.linear(
+            input, self.get_weight(), self.get_bias()
+        )


class ScaledConv1d(nn.Conv1d):
    # See docs for ScaledLinear
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
        super(ScaledConv1d, self).__init__(*args, **kwargs)
        initial_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
-            self.register_parameter('bias_scale', None)
+            self.register_parameter("bias_scale", None)

-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in base class

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed

@@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d):
        with torch.no_grad():
            self.weight_scale += torch.tensor(scale / std).log()

    def get_weight(self):
        return self.weight * self.weight_scale.exp()

    def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()

    def forward(self, input: Tensor) -> Tensor:
        F = torch.nn.functional
-        if self.padding_mode != 'zeros':
-            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            self.get_weight(), self.get_bias(), self.stride,
-                            _single(0), self.dilation, self.groups)
-        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
-                        self.padding, self.dilation, self.groups)
+        if self.padding_mode != "zeros":
+            return F.conv1d(
+                F.pad(
+                    input,
+                    self._reversed_padding_repeated_twice,
+                    mode=self.padding_mode,
+                ),
+                self.get_weight(),
+                self.get_bias(),
+                self.stride,
+                _single(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv1d(
+            input,
+            self.get_weight(),
+            self.get_bias(),
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )


class ScaledConv2d(nn.Conv2d):
    # See docs for ScaledLinear
-    def __init__(self, *args,
-                 initial_scale: float = 1.0,
-                 initial_speed: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
        super(ScaledConv2d, self).__init__(*args, **kwargs)
        initial_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
-            self.register_parameter('bias_scale', None)
+            self.register_parameter("bias_scale", None)

-        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in base class

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
@@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d):
        with torch.no_grad():
            self.weight_scale += torch.tensor(scale / std).log()

    def get_weight(self):
        return self.weight * self.weight_scale.exp()

    def get_bias(self):
-        return (None if self.bias is None else
-                self.bias * self.bias_scale.exp())
+        return None if self.bias is None else self.bias * self.bias_scale.exp()

    def _conv_forward(self, input, weight):
        F = torch.nn.functional
-        if self.padding_mode != 'zeros':
-            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.get_bias(), self.stride,
-                            _pair(0), self.dilation, self.groups)
-        return F.conv2d(input, weight, self.get_bias(), self.stride,
-                        self.padding, self.dilation, self.groups)
+        if self.padding_mode != "zeros":
+            return F.conv2d(
+                F.pad(
+                    input,
+                    self._reversed_padding_repeated_twice,
+                    mode=self.padding_mode,
+                ),
+                weight,
+                self.get_bias(),
+                self.stride,
+                _pair(0),
+                self.dilation,
+                self.groups,
+            )
+        return F.conv2d(
+            input,
+            weight,
+            self.get_bias(),
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.get_weight())


class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for

@@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module):
           we allow, before we start to modify the derivatives to prevent
           this.
    """

-    def __init__(self, channel_dim: int,
-                 min_positive: float = 0.05,
-                 max_positive: float = 0.95,
-                 max_factor: float = 0.01,
-                 min_abs: float = 0.2,
-                 max_abs: float = 100.0):
+    def __init__(
+        self,
+        channel_dim: int,
+        min_positive: float = 0.05,
+        max_positive: float = 0.95,
+        max_factor: float = 0.01,
+        min_abs: float = 0.2,
+        max_abs: float = 100.0,
+    ):
        super(ActivationBalancer, self).__init__()
        self.channel_dim = channel_dim
        self.min_positive = min_positive

@@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module):
        self.max_abs = max_abs

    def forward(self, x: Tensor) -> Tensor:
-        return ActivationBalancerFunction.apply(x, self.channel_dim,
-                                                self.min_positive, self.max_positive,
-                                                self.max_factor, self.min_abs,
-                                                self.max_abs)
+        return ActivationBalancerFunction.apply(
+            x,
+            self.channel_dim,
+            self.min_positive,
+            self.max_positive,
+            self.max_factor,
+            self.min_abs,
+            self.max_abs,
+        )


class DoubleSwishFunction(torch.autograd.Function):

@@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function):
     = double_swish(x) * (1-s(x)) + s(x)
    ... so we just need to remember s(x) but not x itself.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        x = x.detach()

@@ -349,18 +436,17 @@ class DoubleSwishFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        s, y = ctx.saved_tensors
-        return (y * (1-s) + s) * y_grad
+        return (y * (1 - s) + s) * y_grad


class DoubleSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        return DoubleSwishFunction.apply(x)


class ScaledEmbedding(nn.Module):
    r"""This is a modified version of nn.Embedding that introduces a learnable scale
    on the parameters.  Note: due to how we initialize it, it's best used with
@@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module):
             [-0.1655,  0.9897,  0.0635]]])
    """
-    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
-                     'scale_grad_by_freq', 'sparse']
+    __constants__ = [
+        "num_embeddings",
+        "embedding_dim",
+        "padding_idx",
+        "scale_grad_by_freq",
+        "sparse",
+    ]

    num_embeddings: int
    embedding_dim: int

@@ -453,33 +544,41 @@ class ScaledEmbedding(nn.Module):
    weight: Tensor
    sparse: bool

-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 initial_speed: float = 1.0) -> None:
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        initial_speed: float = 1.0,
+    ) -> None:
        super(ScaledEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
-                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert (
+                    padding_idx < self.num_embeddings
+                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
-                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert (
+                    padding_idx >= -self.num_embeddings
+                ), "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.scale_grad_by_freq = scale_grad_by_freq

        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
        self.sparse = sparse

        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.reset_parameters(initial_speed)

    def reset_parameters(self, initial_speed: float = 1.0) -> None:
        std = 0.1 / initial_speed
        nn.init.normal_(self.weight, std=std)
-        nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
+        nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())

        if self.padding_idx is not None:
            with torch.no_grad():
@@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module):
        F = torch.nn.functional
        scale = self.scale.exp()
        if input.numel() < self.num_embeddings:
-            return F.embedding(
-                input, self.weight, self.padding_idx,
-                None, 2.0,  # None, 2.0 relate to normalization
-                self.scale_grad_by_freq, self.sparse) * scale
+            return (
+                F.embedding(
+                    input,
+                    self.weight,
+                    self.padding_idx,
+                    None,
+                    2.0,  # None, 2.0 relate to normalization
+                    self.scale_grad_by_freq,
+                    self.sparse,
+                )
+                * scale
+            )
        else:
            return F.embedding(
-                input, self.weight * scale, self.padding_idx,
-                None, 2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq, self.sparse)
+                input,
+                self.weight * scale,
+                self.padding_idx,
+                None,
+                2.0,  # None, 2.0 relates to normalization
+                self.scale_grad_by_freq,
+                self.sparse,
+            )

    def extra_repr(self) -> str:
-        s = '{num_embeddings}, {embedding_dim}, scale={scale}'
+        s = "{num_embeddings}, {embedding_dim}, scale={scale}"
        if self.padding_idx is not None:
-            s += ', padding_idx={padding_idx}'
+            s += ", padding_idx={padding_idx}"
        if self.scale_grad_by_freq is not False:
-            s += ', scale_grad_by_freq={scale_grad_by_freq}'
+            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        if self.sparse is not False:
-            s += ', sparse=True'
+            s += ", sparse=True"
        return s.format(**self.__dict__)


def _test_activation_balancer_sign():
-    channel_dim = 0
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
    x = x.detach()
    x.requires_grad = True
-    m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
-                           max_factor=0.2, min_abs=0.0)
+    m = ActivationBalancer(
+        channel_dim=0,
+        min_positive=0.05,
+        max_positive=0.95,
+        max_factor=0.2,
+        min_abs=0.0,
+    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

@@ -528,17 +644,23 @@ def _test_activation_balancer_sign():
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)


def _test_activation_balancer_magnitude():
-    channel_dim = 0
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
-    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
+        -1
+    )
    x = x.detach()
    x.requires_grad = True
-    m = ActivationBalancer(channel_dim=0,
-                           min_positive=0.0, max_positive=1.0,
-                           max_factor=0.2,
-                           min_abs=0.2, max_abs=0.8)
+    m = ActivationBalancer(
+        channel_dim=0,
+        min_positive=0.0,
+        max_positive=1.0,
+        max_factor=0.2,
+        min_abs=0.2,
+        max_abs=0.8,
+    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

@@ -558,8 +680,8 @@ def _test_basic_norm():
    y = m(x)

    assert y.shape == x.shape
-    x_rms = (x**2).mean().sqrt()
-    y_rms = (y**2).mean().sqrt()
+    x_rms = (x ** 2).mean().sqrt()
+    y_rms = (y ** 2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms

@@ -573,7 +695,7 @@ def _test_double_swish_deriv():
    torch.autograd.gradcheck(m, x)


-if __name__ == '__main__':
+if __name__ == "__main__":
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()
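The BasicNorm.forward expression reformatted earlier in this file scales each frame by the inverse root of its per-channel second moment plus a learned epsilon; a standalone numeric check of that formula with the default eps of 0.25, mirroring what _test_basic_norm verifies (illustration only, not code from this commit):

import torch

x = torch.randn(4, 10, 512)
eps = torch.tensor(0.25).log()  # BasicNorm stores eps in log space
scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
y = x * scales

# Every element is scaled by a factor below 1 for unit-variance input,
# so the overall RMS shrinks, as the test in this file asserts.
assert (y ** 2).mean().sqrt() < (x ** 2).mean().sqrt()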

View File

@@ -45,16 +45,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse
import logging
-import math
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2
+import optim
import sentencepiece as spm
import torch
-import optim  # from .
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule

@@ -65,27 +64,24 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
-from optim import Eve, Eden
+from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

+from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall import diagnostics
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]


def get_parser():
    parser = argparse.ArgumentParser(
@@ -168,7 +164,7 @@ def get_parser():
        type=float,
        default=5000,
        help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this."""
+        We suggest not to change this.""",
    )

    parser.add_argument(

@@ -176,7 +172,7 @@ def get_parser():
        type=float,
        default=6,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
-        """
+        """,
    )

    parser.add_argument(

@@ -335,7 +331,7 @@ def get_params() -> AttributeDict:
            # parameters for joiner
            "joiner_dim": 512,
            # parameters for Noam
-            "model_warm_step": 3000, # arg given to model, not for lrate
+            "model_warm_step": 3000,  # arg given to model, not for lrate
            "env_info": get_env_info(),
        }
    )

@@ -510,7 +506,7 @@ def compute_loss(
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
-    warmup: float = 1.0
+    warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute CTC loss given the model and its inputs.

@@ -557,18 +553,24 @@ def compute_loss(
        # for the same amount of time (model_warm_step), to avoid
        # overwhelming the simple_loss and causing it to diverge,
        # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (0.0 if warmup < 1.0 else
-                             (0.1 if warmup > 1.0 and warmup < 2.0 else
-                              1.0))
-        loss = (params.simple_loss_scale * simple_loss +
-                pruned_loss_scale * pruned_loss)
+        pruned_loss_scale = (
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
+        )

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
@@ -675,7 +677,7 @@ def train_one_epoch(
                sp=sp,
                batch=batch,
                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step)
+                warmup=(params.batch_idx_train / params.model_warm_step),
            )
        # summary stats
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

@@ -691,8 +693,10 @@ def train_one_epoch(
        if params.print_diagnostics and batch_idx == 5:
            return

-        if (params.batch_idx_train > 0
-                and params.batch_idx_train % params.save_every_n == 0):
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
            params.cur_batch_idx = batch_idx
            save_checkpoint_with_global_batch_idx(
                out_dir=params.exp_dir,

@@ -723,7 +727,7 @@ def train_one_epoch(
            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )

                loss_info.write_summary(

@@ -813,18 +817,19 @@ def run(rank, world_size, args):
        model = DDP(model, device_ids=[rank])
    model.device = device

-    optimizer = Eve(
-        model.parameters(),
-        lr=params.initial_lr)
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

-    if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None:
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

@@ -834,7 +839,6 @@ def run(rank, world_size, args):
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    librispeech = LibriSpeechAsrDataModule(args)

    train_cuts = librispeech.train_clean_100_cuts()

@@ -889,7 +893,6 @@ def run(rank, world_size, args):
        fix_random_seed(params.seed + epoch)
        train_dl.sampler.set_epoch(epoch)

-        cur_lr = scheduler.get_last_lr()[0]
        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

@@ -956,7 +959,7 @@ def scan_pessimistic_batches_for_oom(
                sp=sp,
                batch=batch,
                is_training=True,
-                warmup = 0.0
+                warmup=0.0,
            )
            loss.backward()
            optimizer.step()
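The compute_loss block reformatted above keeps the same warmup schedule for the pruned RNN-T loss; written out as a small helper, with warmup = batch_idx_train / model_warm_step and model_warm_step defaulting to 3000 as set in get_params (illustration only, not code from this commit):

def pruned_loss_scale(warmup: float) -> float:
    # Pruned loss is off while warmup < 1.0, damped to 0.1 while 1.0 < warmup < 2.0,
    # and fully enabled afterwards, matching the reformatted expression above.
    return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)


assert pruned_loss_scale(0.5) == 0.0
assert pruned_loss_scale(1.5) == 0.1
assert pruned_loss_scale(2.5) == 1.0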

View File

@@ -486,7 +486,9 @@ def modified_beam_search(
    for i in range(batch_size):
        topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

-        topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
+        topk_hyp_indexes = torch.div(
+            topk_indexes, vocab_size, rounding_mode="trunc"
+        )
        topk_hyp_indexes = topk_hyp_indexes.tolist()
        topk_token_indexes = (topk_indexes % vocab_size).tolist()
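The torch.div call reformatted above splits a flat top-k index over a (num_hyps, vocab_size) score table into a hypothesis index and a token index; a tiny standalone check (illustration only, not code from this commit):

import torch

vocab_size = 500
topk_indexes = torch.tensor([0, 499, 500, 1234])

topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
topk_token_indexes = topk_indexes % vocab_size

assert topk_hyp_indexes.tolist() == [0, 0, 1, 2]
assert topk_token_indexes.tolist() == [0, 499, 0, 234]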

View File

@@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

# use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object


def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],