Reduce speed of some components

Daniel Povey 2022-03-30 11:46:23 +08:00
parent 7c46c3b0d4
commit 37ab0bcfa5
3 changed files with 64 additions and 22 deletions


@@ -967,16 +967,24 @@ class Conv2dSubsampling(nn.Module):
         """
         assert in_channels >= 7
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values.  The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.conv = nn.Sequential(
             ScaledConv2d(
                 in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
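
The comment above explains the intent: enlarging the initial parameter magnitudes makes the module learn more slowly, relative to its size, during the warmup phase. A rough sketch of that effect (not part of the commit; it assumes an Adam-style optimizer whose per-element step size is roughly independent of parameter magnitude, and the function name and tensor shape are illustrative):

import torch

def relative_change_after_one_step(initial_speed: float, lr: float = 1e-3) -> float:
    # Mirrors the new initialization: std = 0.01 / initial_speed, so a smaller
    # initial_speed means larger initial weights.
    std = 0.01 / initial_speed
    weight = torch.randn(512, 512) * std
    # An Adam-style update moves each element by roughly lr, independent of
    # the parameter's current magnitude.
    update = torch.full_like(weight, lr)
    return (update.norm() / weight.norm()).item()

print(relative_change_after_one_step(1.0))  # baseline
print(relative_change_after_one_step(0.5))  # about half the relative change per step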


@@ -55,10 +55,17 @@ class Decoder(nn.Module):
             1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values.  The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
+            initial_speed=initial_speed
         )
 
         self.blank_id = blank_id


@@ -134,13 +134,18 @@ class ScaledLinear(nn.Linear):
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
-
-    Note: it uses the default initialization for the weight and bias,
-    inherited from nn.Linear.  For modules with small fan-in, this
-    may be larger than optimal.
+        initial_speed: this affects how fast the parameter will
+           learn near the start of training; you can set it to a
+           value less than one if you suspect that a module
+           is contributing to instability near the start of training.
+           Note: regardless of the use of this option, it's best to
+           use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster.  Must be greater than 0.
     """
     def __init__(self, *args,
                  initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
                  **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
@@ -150,10 +155,10 @@ class ScaledLinear(nn.Linear):
         else:
             self.register_parameter('bias_scale', None)
 
-        self._reset_parameters()  # Overrides the reset_parameters in nn.Linear
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
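
A usage illustration of the new argument (not from the commit; the import path is assumed, and the numbers follow the std = 0.01 / initial_speed initialization shown above):

from scaling import ScaledLinear  # assumed import path for this module

# With initial_speed=0.5, _reset_parameters uses std = 0.01 / 0.5 = 0.02, so the
# weights start uniform in [-sqrt(3) * 0.02, sqrt(3) * 0.02].
linear = ScaledLinear(256, 256, initial_speed=0.5)
print(linear.weight.std())        # roughly 0.02 (roughly 0.01 with the default speed)
print(linear.weight.abs().max())  # at most sqrt(3) * 0.02, about 0.035
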
@@ -176,8 +181,11 @@ class ScaledLinear(nn.Linear):
 
 class ScaledConv1d(nn.Conv1d):
+    # See docs for ScaledLinear
     def __init__(self, *args,
-                 initial_scale=1.0, **kwargs):
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -185,10 +193,10 @@ class ScaledConv1d(nn.Conv1d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -218,7 +226,11 @@ class ScaledConv1d(nn.Conv1d):
 
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, initial_scale=1.0, **kwargs):
+    # See docs for ScaledLinear
+    def __init__(self, *args,
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -226,10 +238,10 @@ class ScaledConv2d(nn.Conv2d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -350,7 +362,11 @@ class DoubleSwish(torch.nn.Module):
 
 class ScaledEmbedding(nn.Module):
-    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
+    r"""This is a modified version of nn.Embedding that introduces a learnable scale
+    on the parameters.  Note: due to how we initialize it, it's best used with
+    schedulers like Noam that have a warmup period.
+
+    It is a simple lookup table that stores embeddings of a fixed dictionary and size.
 
     This module is often used to store word embeddings and retrieve them using indices.
     The input to the module is a list of indices, and the output is the corresponding
@@ -369,6 +385,15 @@ class ScaledEmbedding(nn.Module):
         sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                  See Notes for more details regarding sparse gradients.
+        initial_speed (float, optional): This affects how fast the parameter will
+            learn near the start of training; you can set it to a value less than
+            one if you suspect that a module is contributing to instability near
+            the start of training.  Note: regardless of the use of this option,
+            it's best to use schedulers like Noam that have a warm-up period.
+            Alternatively you can set it to more than 1 if you want it to
+            initially train faster.  Must be greater than 0.
 
     Attributes:
         weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
             initialized from :math:`\mathcal{N}(0, 1)`
@@ -416,6 +441,7 @@ class ScaledEmbedding(nn.Module):
                  [ 0.1535, -2.0309,  0.9315],
                  [ 0.0000,  0.0000,  0.0000],
                  [-0.1655,  0.9897,  0.0635]]])
     """
 
     __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
                      'scale_grad_by_freq', 'sparse']
@@ -429,7 +455,8 @@ class ScaledEmbedding(nn.Module):
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                  scale_grad_by_freq: bool = False,
-                 sparse: bool = False) -> None:
+                 sparse: bool = False,
+                 initial_speed: float = 1.0) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -446,12 +473,12 @@ class ScaledEmbedding(nn.Module):
         self.sparse = sparse
 
         self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
+        self.reset_parameters(initial_speed)
 
-    def reset_parameters(self) -> None:
-        std = 0.01
+    def reset_parameters(self, initial_speed: float = 1.0) -> None:
+        std = 0.01 / initial_speed
         nn.init.normal_(self.weight, std=std)
         nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
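
A numerical sketch (not part of the commit) of what this initialization does, assuming the forward pass multiplies the embedding weights by self.scale.exp() as the initialization above suggests: the scale cancels the weight std exactly at init, so initial_speed changes the raw parameter magnitude, and hence the relative learning speed, without changing the effective output scale.

import math

for initial_speed in (0.5, 1.0, 2.0):
    std = 0.01 / initial_speed          # raw weight std, as in reset_parameters
    scale = math.log(1.0 / std)         # value written into self.scale
    print(initial_speed, std * math.exp(scale))  # effective std: always 1.0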