Code style check for librispeech pruned transducer stateless2 (#308)

Mingshuang Luo 2022-04-11 22:15:18 +08:00 committed by GitHub
parent 8cb727e24a
commit 93c60a9d30
11 changed files with 484 additions and 296 deletions

View File

@@ -7,6 +7,8 @@ per-file-ignores =
egs/librispeech/ASR/*/conformer.py: E501,
egs/aishell/ASR/*/conformer.py: E501,
egs/tedlium3/ASR/*/conformer.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
# invalid escape sequence (caused by TeX formulas), W605
icefall/utils.py: E501, W605
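
For context, not part of the diff: E501 is pycodestyle's "line too long" check, and W605 fires on invalid escape sequences, e.g. TeX markup inside a plain (non-raw) docstring. A minimal, hypothetical illustration of the W605 case that icefall/utils.py hits:

def gaussian_density(x):
    """Evaluates e^{-x^2 / 2} / \sqrt{2 \pi}."""  # W605: \s and \p are invalid escapes
    return 2.718281828459045 ** (-x * x / 2) / 2.5066282746310002

def gaussian_density_fixed(x):
    r"""Evaluates e^{-x^2 / 2} / \sqrt{2 \pi}."""  # raw string: no warning
    return 2.718281828459045 ** (-x * x / 2) / 2.5066282746310002

The config simply silences W605 for that file instead of converting its docstrings to raw strings.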

View File

@@ -93,7 +93,9 @@ def fast_beam_search(
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
@@ -140,7 +142,6 @@ def greedy_search(
encoder_out = model.joiner.encoder_proj(encoder_out)
T = encoder_out.size(1)
t = 0
hyp = [blank_id] * context_size
@@ -163,9 +164,9 @@ def greedy_search(
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(current_encoder_out,
decoder_out.unsqueeze(1),
project_input=False)
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
@@ -228,8 +229,9 @@ def greedy_search_batch(
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1),
project_input=False)
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits' shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
@@ -466,7 +468,6 @@ def modified_beam_search(
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
@@ -720,7 +721,7 @@ def beam_search(
logits = model.joiner(
current_encoder_out,
decoder_out.unsqueeze(1),
project_input=False
project_input=False,
)
# TODO(fangjun): Scale the blank posterior
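
A pattern worth noting across these hunks: the search routines call model.joiner.encoder_proj and model.joiner.decoder_proj once, outside the per-frame loop, then pass project_input=False so the joiner skips its internal projections at every step. A minimal sketch of the idea, with toy dimensions and a hypothetical TinyJoiner standing in for the real class:

import torch
import torch.nn as nn

class TinyJoiner(nn.Module):
    # Toy stand-in for the Joiner in this commit: two input projections
    # plus an output layer.
    def __init__(self, enc_dim=4, dec_dim=3, join_dim=5, vocab=7):
        super().__init__()
        self.encoder_proj = nn.Linear(enc_dim, join_dim)
        self.decoder_proj = nn.Linear(dec_dim, join_dim)
        self.output_linear = nn.Linear(join_dim, vocab)

    def forward(self, enc, dec, project_input=True):
        if project_input:
            logit = self.encoder_proj(enc) + self.decoder_proj(dec)
        else:
            # Caller already projected both inputs, once per utterance
            # rather than once per decoding step.
            logit = enc + dec
        return self.output_linear(torch.tanh(logit))

j = TinyJoiner()
enc = j.encoder_proj(torch.randn(2, 10, 4))  # hoisted out of the loop
dec = j.decoder_proj(torch.randn(2, 10, 3))
assert j(enc, dec, project_input=False).shape == (2, 10, 7)

Hoisting the projections trades a little memory for avoiding two matrix multiplies per decoding step.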

View File

@@ -16,13 +16,20 @@
# limitations under the License.
import copy
from encoder_interface import EncoderInterface
import math
import warnings
from typing import Optional, Tuple, Sequence
from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
from typing import Optional, Tuple
import torch
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask
@@ -42,6 +49,7 @@ class Conformer(EncoderInterface):
cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
@@ -80,7 +88,6 @@ class Conformer(EncoderInterface):
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -112,8 +119,9 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
warmup=warmup) # (T, N, C)
x = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_final = BasicNorm(d_model)
# try to ensure the output is close to zero-mean (or at least, zero-median).
self.balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
max_abs=6.0)
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: Tensor,
@@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module):
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
if self.training:
alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1
alpha = (
warmup_scale
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
else 0.1
)
else:
alpha = 1.0
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
# multi-headed self-attention module
src_att = self.self_attn(
src,
@@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1-alpha) * src_orig
src = alpha * src + (1 - alpha) * src_orig
return src
@@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module):
)
self.num_layers = num_layers
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0
warmup: float = 1.0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module):
"""
output = src
num_layers = len(self.layers)
for i, mod in enumerate(self.layers):
output = mod(
output,
@@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module):
), "embed_dim must be divisible by num_heads"
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
self.out_proj = ScaledLinear(
embed_dim, embed_dim, bias=True, initial_scale=0.25
)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
@@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
@@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
@@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module):
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
# it will be in a better position to start learning something, i.e. to latch onto
# the correct range.
self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
min_positive=0.05,
max_positive=1.0)
self.deriv_balancer1 = ActivationBalancer(
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
self.depthwise_conv = ScaledConv1d(
channels,
@@ -878,9 +885,9 @@ class ConvolutionModule(nn.Module):
bias=bias,
)
self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
min_positive=0.05,
max_positive=1.0)
self.deriv_balancer2 = ActivationBalancer(
channel_dim=1, min_positive=0.05, max_positive=1.0
)
self.activation = DoubleSwish()
@@ -891,7 +898,7 @@ class ConvolutionModule(nn.Module):
stride=1,
padding=0,
bias=bias,
initial_scale=0.25
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
@@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1)
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
@@ -936,11 +942,14 @@ class Conv2dSubsampling(nn.Module):
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, in_channels: int,
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128) -> None:
layer3_channels: int = 128,
) -> None:
"""
Args:
in_channels:
@@ -958,34 +967,41 @@ class Conv2dSubsampling(nn.Module):
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1, out_channels=layer1_channels,
kernel_size=3, padding=1,
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=1,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels, out_channels=layer2_channels,
kernel_size=3, stride=2,
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer2_channels, out_channels=layer3_channels,
kernel_size=3, stride=2,
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
self.out = ScaledLinear(
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55)
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
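
The input dimension of `out` above encodes the frequency-axis shrinkage: the padded 3x3 conv keeps all in_channels bins, then each unpadded stride-2 3x3 conv maps n bins to (n - 1) // 2. A quick check, assuming the 80-bin fbank features these recipes typically use (an assumption, not stated in this hunk):

in_channels, layer3_channels = 80, 128
freq = ((in_channels - 1) // 2 - 1) // 2  # 80 -> 39 -> 19
assert freq == 19 and layer3_channels * freq == 2432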
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module):
return x
if __name__ == '__main__':
if __name__ == "__main__":
feature_dim = 50
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(torch.randn(batch_size, seq_len, feature_dim),
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5)
warmup=0.5,
)

View File

@@ -17,9 +17,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
from scaling import ScaledConv1d, ScaledEmbedding
class Decoder(nn.Module):

View File

@@ -16,15 +16,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScaledLinear
class Joiner(nn.Module):
def __init__(self,
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int):
vocab_size: int,
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
@@ -32,8 +34,10 @@ class Joiner(nn.Module):
self.output_linear = ScaledLinear(joiner_dim, vocab_size)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor,
project_input: bool = True
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
@@ -52,7 +56,9 @@ class Joiner(nn.Module):
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
decoder_out
)
else:
logit = encoder_out + decoder_out

View File

@@ -37,7 +37,7 @@ class Transducer(nn.Module):
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int
vocab_size: int,
):
"""
Args:
@@ -48,11 +48,11 @@ class Transducer(nn.Module):
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim). It should contain
one attribute: `blank_id`.
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
output shape is (N, T, U, vocab_size). Note that its output contains
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
@@ -63,8 +63,9 @@ class Transducer(nn.Module):
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
initial_speed=0.5)
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
def forward(
@@ -141,8 +142,8 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm=self.simple_lm_proj(decoder_out)
am=self.simple_am_proj(encoder_out)
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
@@ -170,15 +171,14 @@ class Transducer(nn.Module):
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned,
project_input=False)
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
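
(The hunk above is truncated by the diff view.) For readers new to pruned RNN-T: the simple am/lm projections give k2 cheap per-frame and per-symbol scores, k2 derives a window of prune_range symbol positions per frame from their gradients, and do_rnnt_pruning gathers the projected am/lm tensors down to that window before the full joiner runs. A toy, pure-torch sketch of the gather step (hypothetical values; the real ranges come from k2):

import torch

B, T, U, D, s_range = 1, 4, 6, 2, 3
lm = torch.randn(B, U, D)  # decoder output: one vector per symbol position
# Hypothetical window starts per frame (k2 computes these from the
# simple-loss gradients):
starts = torch.tensor([[0, 1, 2, 3]])  # (B, T)
idx = starts.unsqueeze(-1) + torch.arange(s_range)  # (B, T, s_range)
lm_pruned = lm.unsqueeze(1).expand(B, T, U, D).gather(
    2, idx.unsqueeze(-1).expand(B, T, s_range, D)
)
# The joiner now evaluates T * s_range grid cells instead of T * U.
assert lm_pruned.shape == (B, T, s_range, D)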

View File

@@ -15,11 +15,9 @@
# limitations under the License.
import random
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.optim import Optimizer
@@ -59,24 +57,41 @@ class Eve(Optimizer):
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8,
weight_decay=1e-3, target_rms=0.1):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.98),
eps=1e-8,
weight_decay=1e-3,
target_rms=0.1,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 <= weight_decay <= 0.1:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(lr=lr, betas=betas, eps=eps,
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
target_rms=target_rms)
target_rms=target_rms,
)
super(Eve, self).__init__(params, defaults)
def __setstate__(self, state):
@@ -96,83 +111,98 @@ class Eve(Optimizer):
loss = closure()
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
raise RuntimeError(
"AdamW does not support sparse gradients"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
state["step"] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group['betas']
beta1, beta2 = group["betas"]
state['step'] += 1
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
state["step"] += 1
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps'])
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
group["eps"]
)
step_size = group['lr'] / bias_correction1
target_rms = group['target_rms']
weight_decay = group['weight_decay']
delta = exp_avg / denom
step_size = group["lr"] / bias_correction1
target_rms = group["target_rms"]
weight_decay = group["weight_decay"]
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5)))
is_above_target_rms = p.norm() > (
target_rms * (p.numel() ** 0.5)
)
p.mul_(1 - (weight_decay * is_above_target_rms))
p.addcdiv_(exp_avg, denom, value=-step_size)
return loss
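
The numel() > 1 branch above is Eve's main departure from AdamW: weight decay only applies while a tensor's root-mean-square value exceeds target_rms, since p.norm() > target_rms * sqrt(numel) is equivalent to rms(p) > target_rms. A quick numeric check of that equivalence:

import torch

p = torch.randn(300, 400) * 0.2  # rms roughly 0.2
target_rms = 0.1
rms = (p ** 2).mean().sqrt()
assert (p.norm() > target_rms * p.numel() ** 0.5) == (rms > target_rms)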
class LRScheduler(object):
"""
Base-class for learning rate schedulers where the learning-rate depends on both the
batch and the epoch.
"""
def __init__(self, optimizer: Optimizer, verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError('{} is not an Optimizer'.format(
type(optimizer).__name__))
raise TypeError(
"{} is not an Optimizer".format(type(optimizer).__name__)
)
self.optimizer = optimizer
self.verbose = verbose
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
group.setdefault("initial_lr", group["lr"])
self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
self.base_lrs = [
group["initial_lr"] for group in optimizer.param_groups
]
self.epoch = 0
self.batch = 0
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
return {'base_lrs': self.base_lrs,
'epoch': self.epoch,
'batch': self.batch}
return {
"base_lrs": self.base_lrs,
"epoch": self.epoch,
"batch": self.batch,
}
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
@@ -184,8 +214,7 @@ class LRScheduler(object):
self.__dict__.update(state_dict)
def get_last_lr(self) -> List[float]:
""" Return last computed learning rate by current scheduler. Will be a list of float.
"""
"""Return last computed learning rate by current scheduler. Will be a list of float."""
return self._last_lr
def get_lr(self):
@@ -194,7 +223,6 @@ class LRScheduler(object):
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
raise NotImplementedError
def step_batch(self, batch: Optional[int] = None) -> None:
# Step the batch index, or just set it. If `batch` is specified, it
# must be the batch index from the start of training, i.e. summed over
@@ -217,24 +245,23 @@ class LRScheduler(object):
self.epoch = self.epoch + 1
self._set_lrs()
def _set_lrs(self):
values = self.get_lr()
assert len(values) == len(self.optimizer.param_groups)
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group['lr'] = lr
param_group["lr"] = lr
self.print_lr(self.verbose, i, lr)
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate.
"""
"""Display the current learning rate."""
if is_verbose:
print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
f' of group {group} to {lr:.4e}.')
print(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}."
)
class Eden(LRScheduler):
@@ -254,18 +281,27 @@ class Eden(LRScheduler):
20 to 40 epochs, but may need smaller number if dataset is huge
and you will do few epochs.
"""
def __init__(self, optimizer: Optimizer,
def __init__(
self,
optimizer: Optimizer,
lr_batches: Union[int, float],
lr_epochs: Union[int, float],
verbose: bool = False):
verbose: bool = False,
):
super(Eden, self).__init__(optimizer, verbose)
self.lr_batches = lr_batches
self.lr_epochs = lr_epochs
def get_lr(self):
factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
(((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
return [ x * factor for x in self.base_lrs ]
factor = (
(self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
) ** -0.25 * (
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
** -0.25
)
return [x * factor for x in self.base_lrs]
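
The reformatted factor equals 1.0 at batch = epoch = 0 and, once batch >> lr_batches, decays roughly like (lr_batches / batch) ** 0.5 (and analogously in epochs). A quick tabulation with the defaults this commit's train.py keeps (--lr-batches 5000, --lr-epochs 6):

lr_batches, lr_epochs = 5000, 6

def factor(batch, epoch):
    return ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25 * (
        (epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2
    ) ** -0.25

print(factor(0, 0))      # 1.0 at the start of training
print(factor(5000, 0))   # 2 ** -0.25, about 0.84, after lr_batches steps
print(factor(50000, 6))  # about 0.27: both decay terms have kicked in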
def _test_eden():
m = torch.nn.Linear(100, 100)
@@ -290,5 +326,6 @@ def _test_eden():
print("last lr = ", scheduler.get_last_lr())
print("state dict = ", scheduler.state_dict())
if __name__ == '__main__':
if __name__ == "__main__":
_test_eden()

View File

@@ -15,18 +15,33 @@
# limitations under the License.
import collections
from itertools import repeat
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Optional
def _ntuple(n):
def parse(x):
if isinstance(x, collections.Iterable):
return x
return tuple(repeat(x, n))
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
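
One portability caveat in the context lines above (unchanged by this commit): isinstance(x, collections.Iterable) relies on an alias deprecated in Python 3.3 and removed in 3.10, so current interpreters need the collections.abc spelling. A sketch of the portable form:

import collections.abc
from itertools import repeat

def _ntuple_portable(n):
    def parse(x):
        # Same behavior as _ntuple above, but works on Python >= 3.10.
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

assert _ntuple_portable(2)(3) == (3, 3)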
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor,
def forward(
ctx,
x: Tensor,
channel_dim: int,
min_positive: float, # e.g. 0.05
max_positive: float, # e.g. 0.95
@@ -39,30 +54,47 @@ class ActivationBalancerFunction(torch.autograd.Function):
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
xgt0 = x > 0
proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True)
factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive)
if min_positive != 0.0 else 0.0)
factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0))
if max_positive != 1.0 else 0.0)
proportion_positive = torch.mean(
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
)
factor1 = (
(min_positive - proportion_positive).relu()
* (max_factor / min_positive)
if min_positive != 0.0
else 0.0
)
factor2 = (
(proportion_positive - max_positive).relu()
* (max_factor / (max_positive - 1.0))
if max_positive != 1.0
else 0.0
)
factor = factor1 + factor2
if isinstance(factor, float):
factor = torch.zeros_like(proportion_positive)
mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True)
below_threshold = (mean_abs < min_abs)
above_threshold = (mean_abs > max_abs)
below_threshold = mean_abs < min_abs
above_threshold = mean_abs > max_abs
ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold)
ctx.save_for_backward(
factor, xgt0, below_threshold, above_threshold
)
ctx.max_factor = max_factor
ctx.sum_dims = sum_dims
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]:
def backward(
ctx, x_grad: Tensor
) -> Tuple[Tensor, None, None, None, None, None, None]:
factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors
dtype = x_grad.dtype
scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) *
(xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0))
scale_factor = (
(below_threshold.to(dtype) - above_threshold.to(dtype))
* (xgt0.to(dtype) - 0.5)
* (ctx.max_factor * 2.0)
)
neg_delta_grad = x_grad.abs() * (factor + scale_factor)
return x_grad - neg_delta_grad, None, None, None, None, None, None
@@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module):
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
"""
def __init__(self,
def __init__(
self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25,
learn_eps: bool = True) -> None:
learn_eps: bool = True,
) -> None:
super(BasicNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
if learn_eps:
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
else:
self.register_buffer('eps', torch.tensor(eps).log().detach())
self.register_buffer("eps", torch.tensor(eps).log().detach())
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
self.eps.exp()) ** -0.5
scales = (
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+ self.eps.exp()
) ** -0.5
return x * scales
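
BasicNorm is essentially RMSNorm with a learned epsilon stored in log space: the multiplier is (mean(x ** 2) + exp(eps)) ** -0.5, so the output's mean square over the channel dimension is mean(x ** 2) / (mean(x ** 2) + exp(eps)), strictly below 1. A standalone numeric check mirroring the default eps = 0.25:

import torch

x = torch.randn(2, 8)
log_eps = torch.tensor(0.25).log()  # the stored parameter is log(eps)
scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + log_eps.exp()) ** -0.5
y = x * scales
assert ((y ** 2).mean(dim=-1) < 1).all()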
class ScaledLinear(nn.Linear):
"""
A modified version of nn.Linear where the parameters are scaled before
@@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear):
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
"""
def __init__(self, *args,
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs):
**kwargs
):
super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self.register_parameter("bias_scale", None)
self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in nn.Linear
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
@@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * self.bias_scale.exp())
return None if self.bias is None else self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(input, self.get_weight(),
self.get_bias())
return torch.nn.functional.linear(
input, self.get_weight(), self.get_bias()
)
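
The Scaled* family stores a log-domain scalar next to each tensor and multiplies exp(scale) back in at forward time, so a weight's overall magnitude becomes a single learnable degree of freedom, adjustable independently of its direction. A minimal sketch of the reparameterization (toy values, not the full class):

import torch
import torch.nn as nn

weight = nn.Parameter(torch.randn(4, 4) * 0.1)
weight_scale = nn.Parameter(torch.tensor(2.0).log())  # kept in log space
effective = weight * weight_scale.exp()  # what forward() actually uses
assert torch.allclose(effective, weight * 2.0)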
class ScaledConv1d(nn.Conv1d):
# See docs for ScaledLinear
def __init__(self, *args,
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs):
**kwargs
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
@@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d):
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
def get_weight(self):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * self.bias_scale.exp())
return None if self.bias is None else self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
self.get_weight(), self.get_bias(), self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
if self.padding_mode != "zeros":
return F.conv1d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
self.get_weight(),
self.get_bias(),
self.stride,
_single(0),
self.dilation,
self.groups,
)
return F.conv1d(
input,
self.get_weight(),
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
class ScaledConv2d(nn.Conv2d):
# See docs for ScaledLinear
def __init__(self, *args,
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs):
**kwargs
):
super(ScaledConv2d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
self.register_parameter("bias_scale", None)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
@@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d):
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
def get_weight(self):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return (None if self.bias is None else
self.bias * self.bias_scale.exp())
return None if self.bias is None else self.bias * self.bias_scale.exp()
def _conv_forward(self, input, weight):
F = torch.nn.functional
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.get_bias(), self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.get_bias(), self.stride,
self.padding, self.dilation, self.groups)
if self.padding_mode != "zeros":
return F.conv2d(
F.pad(
input,
self._reversed_padding_repeated_twice,
mode=self.padding_mode,
),
weight,
self.get_bias(),
self.stride,
_pair(0),
self.dilation,
self.groups,
)
return F.conv2d(
input,
weight,
self.get_bias(),
self.stride,
self.padding,
self.dilation,
self.groups,
)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.get_weight())
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for
@@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module):
we allow, before we start to modify the derivatives to prevent
this.
"""
def __init__(self, channel_dim: int,
def __init__(
self,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0):
max_abs: float = 100.0,
):
super(ActivationBalancer, self).__init__()
self.channel_dim = channel_dim
self.min_positive = min_positive
@@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module):
self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor:
return ActivationBalancerFunction.apply(x, self.channel_dim,
self.min_positive, self.max_positive,
self.max_factor, self.min_abs,
self.max_abs)
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):
@@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function):
= double_swish(x) * (1-s(x)) + s(x)
... so we just need to remember s(x) but not x itself.
"""
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
x = x.detach()
@@ -349,7 +436,8 @@
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
s, y = ctx.saved_tensors
return (y * (1-s) + s) * y_grad
return (y * (1 - s) + s) * y_grad
class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
@@ -359,8 +447,6 @@ class DoubleSwish(torch.nn.Module):
return DoubleSwishFunction.apply(x)
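
The derivative identity quoted in DoubleSwishFunction's docstring, d/dx [x * sigmoid(x - 1)] = double_swish(x) * (1 - s(x)) + s(x), is what lets backward() store only s(x). A standalone autograd check:

import torch

x = torch.linspace(-3, 3, 7, dtype=torch.double, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s  # double_swish(x)
y.sum().backward()
assert torch.allclose(x.grad, y.detach() * (1 - s.detach()) + s.detach())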
class ScaledEmbedding(nn.Module):
r"""This is a modified version of nn.Embedding that introduces a learnable scale
on the parameters. Note: due to how we initialize it, it's best used with
@@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module):
[-0.1655, 0.9897, 0.0635]]])
"""
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
'scale_grad_by_freq', 'sparse']
__constants__ = [
"num_embeddings",
"embedding_dim",
"padding_idx",
"scale_grad_by_freq",
"sparse",
]
num_embeddings: int
embedding_dim: int
@@ -453,18 +544,27 @@ class ScaledEmbedding(nn.Module):
weight: Tensor
sparse: bool
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False,
initial_speed: float = 1.0) -> None:
initial_speed: float = 1.0,
) -> None:
super(ScaledEmbedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
assert (
padding_idx < self.num_embeddings
), "Padding_idx must be within num_embeddings"
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
assert (
padding_idx >= -self.num_embeddings
), "Padding_idx must be within num_embeddings"
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.scale_grad_by_freq = scale_grad_by_freq
@@ -475,11 +575,10 @@ class ScaledEmbedding(nn.Module):
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters(initial_speed)
def reset_parameters(self, initial_speed: float = 1.0) -> None:
std = 0.1 / initial_speed
nn.init.normal_(self.weight, std=std)
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
if self.padding_idx is not None:
with torch.no_grad():
@@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module):
F = torch.nn.functional
scale = self.scale.exp()
if input.numel() < self.num_embeddings:
return F.embedding(
input, self.weight, self.padding_idx,
None, 2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq, self.sparse) * scale
return (
F.embedding(
input,
self.weight,
self.padding_idx,
None,
2.0, # None, 2.0 relate to normalization
self.scale_grad_by_freq,
self.sparse,
)
* scale
)
else:
return F.embedding(
input, self.weight * scale, self.padding_idx,
None, 2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq, self.sparse)
input,
self.weight * scale,
self.padding_idx,
None,
2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq,
self.sparse,
)
def extra_repr(self) -> str:
s = '{num_embeddings}, {embedding_dim}, scale={scale}'
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
if self.padding_idx is not None:
s += ', padding_idx={padding_idx}'
s += ", padding_idx={padding_idx}"
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
s += ", scale_grad_by_freq={scale_grad_by_freq}"
if self.sparse is not False:
s += ', sparse=True'
s += ", sparse=True"
return s.format(**self.__dict__)
def _test_activation_balancer_sign():
channel_dim = 0
probs = torch.arange(0, 1, 0.01)
N = 1000
x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95,
max_factor=0.2, min_abs=0.0)
m = ActivationBalancer(
channel_dim=0,
min_positive=0.05,
max_positive=0.95,
max_factor=0.2,
min_abs=0.0,
)
y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -528,17 +644,23 @@ def _test_activation_balancer_sign():
print("_test_activation_balancer_sign: y grad = ", y_grad)
print("_test_activation_balancer_sign: x grad = ", x.grad)
def _test_activation_balancer_magnitude():
channel_dim = 0
magnitudes = torch.arange(0, 1, 0.01)
N = 1000
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
-1
)
x = x.detach()
x.requires_grad = True
m = ActivationBalancer(channel_dim=0,
min_positive=0.0, max_positive=1.0,
m = ActivationBalancer(
channel_dim=0,
min_positive=0.0,
max_positive=1.0,
max_factor=0.2,
min_abs=0.2, max_abs=0.8)
min_abs=0.2,
max_abs=0.8,
)
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
@@ -558,8 +680,8 @@ def _test_basic_norm():
y = m(x)
assert y.shape == x.shape
x_rms = (x**2).mean().sqrt()
y_rms = (y**2).mean().sqrt()
x_rms = (x ** 2).mean().sqrt()
y_rms = (y ** 2).mean().sqrt()
print("x rms = ", x_rms)
print("y rms = ", y_rms)
assert y_rms < x_rms
@@ -573,7 +695,7 @@ def _test_double_swish_deriv():
torch.autograd.gradcheck(m, x)
if __name__ == '__main__':
if __name__ == "__main__":
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()

View File

@@ -45,16 +45,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse
import logging
import math
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import optim # from .
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
@@ -65,27 +64,24 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eve, Eden
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall import diagnostics
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def get_parser():
parser = argparse.ArgumentParser(
@@ -168,7 +164,7 @@ def get_parser():
type=float,
default=5000,
help="""Number of steps that affects how rapidly the learning rate decreases.
We suggest not to change this."""
We suggest not to change this.""",
)
parser.add_argument(
@@ -176,7 +172,7 @@
type=float,
default=6,
help="""Number of epochs that affects how rapidly the learning rate decreases.
"""
""",
)
parser.add_argument(
@@ -510,7 +506,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
warmup: float = 1.0
warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute transducer loss given the model and its inputs.
@@ -557,18 +553,24 @@
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (0.0 if warmup < 1.0 else
(0.1 if warmup > 1.0 and warmup < 2.0 else
1.0))
loss = (params.simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss)
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
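
Spelled out: the pruned loss is disabled for the first model_warm_step batches (warmup < 1), weighted 0.1 while 1 < warmup < 2, and fully enabled afterwards. A direct transcription of the expression above:

def pruned_loss_scale(warmup: float) -> float:
    # Note warmup == 1.0 falls through to 1.0, matching the original
    # `warmup > 1.0 and warmup < 2.0` condition.
    return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)

assert pruned_loss_scale(0.5) == 0.0
assert pruned_loss_scale(1.5) == 0.1
assert pruned_loss_scale(2.5) == 1.0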
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -675,7 +677,7 @@ def train_one_epoch(
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step)
warmup=(params.batch_idx_train / params.model_warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -691,8 +693,10 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5:
return
if (params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0):
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
@@ -813,18 +817,19 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Eve(
model.parameters(),
lr=params.initial_lr)
optimizer = Eve(model.parameters(), lr=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None:
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
@@ -834,7 +839,6 @@ def run(rank, world_size, args):
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
@@ -889,7 +893,6 @@
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = scheduler.get_last_lr()[0]
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
@@ -956,7 +959,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
warmup = 0.0
warmup=0.0,
)
loss.backward()
optimizer.step()

View File

@@ -486,7 +486,9 @@ def modified_beam_search(
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
topk_hyp_indexes = torch.div(
topk_indexes, vocab_size, rounding_mode="trunc"
)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
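
The div/mod pair undoes the flattening of a (num_hyps, vocab_size) score table: topk runs over the flattened ragged scores, index // vocab_size recovers the hypothesis, and index % vocab_size the token. A toy example:

import torch

vocab_size = 5
scores = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.3,   # hyp 0
                       0.4, 0.0, 0.7, 0.6, 0.5])  # hyp 1
topk_indexes = scores.topk(2).indices                             # [1, 3]
hyp = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")  # [0, 0]
tok = topk_indexes % vocab_size                                   # [1, 3]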

View File

@@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
# use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],