Reduce speed of some components

commit 37ab0bcfa5 (parent 7c46c3b0d4)
Daniel Povey  2022-03-30 11:46:23 +08:00
3 changed files with 64 additions and 22 deletions

View File

@@ -967,16 +967,24 @@ class Conv2dSubsampling(nn.Module):
         """
         assert in_channels >= 7
         super().__init__()
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values.  The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.conv = nn.Sequential(
             ScaledConv2d(
                 in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
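A minimal sketch of the mechanism the comment above describes (an illustration, not code from this commit): these Scaled* modules factor the effective weight as raw_weight * exp(scale). Dividing the init std by initial_speed < 1 enlarges raw_weight, and a compensating scale keeps the module's output at initialization unchanged, so the only effect is that a fixed-size optimizer update now moves raw_weight by a smaller relative amount, i.e. slower learning early in training. The compensation is visible in the ScaledEmbedding hunk further below, where the scale is initialized to log(1/std).

import torch

# Hypothetical illustration (not from this commit): effective weight is
# raw * exp(log_scale).  Enlarging raw while shrinking exp(log_scale)
# leaves the output at init unchanged, but a fixed-size update to raw
# now changes it by a smaller *relative* amount.
for initial_speed in (1.0, 0.5):
    std = 0.01 / initial_speed
    a = (3 ** 0.5) * std                       # uniform(-a, a) has stddev std
    raw = torch.empty(256, 256).uniform_(-a, a)
    log_scale = torch.tensor(1.0 / std).log()  # compensating scale
    effective = raw * log_scale.exp()
    update = 1e-4                              # a fixed-size parameter step
    print(f"initial_speed={initial_speed}: "
          f"effective std={effective.std():.4f}, "
          f"relative update={update / raw.abs().mean():.4f}")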

View File

@@ -55,10 +55,17 @@ class Decoder(nn.Module):
             1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values.  The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
+            initial_speed=initial_speed
         )
         self.blank_id = blank_id
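Both comments refer to lr_factor, and the scaling.py docstrings below recommend "schedulers like Noam that have a warm-up period". For reference, a Noam-style schedule looks like the following sketch; the model_size, lr_factor and warmup values here are illustrative, not taken from this repository.

# Sketch of a Noam-style learning-rate schedule (the "warm-up period"
# the docstrings refer to); all constants below are illustrative.
def noam_lr(step: int, model_size: int = 512,
            lr_factor: float = 5.0, warmup: int = 25000) -> float:
    # Linear warm-up to a peak at `warmup` steps, then 1/sqrt(step) decay.
    return (lr_factor * model_size ** -0.5
            * min((step + 1) ** -0.5, (step + 1) * warmup ** -1.5))

print([round(noam_lr(s), 6) for s in (0, 1000, 25000, 100000)])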

View File

@@ -134,13 +134,18 @@ class ScaledLinear(nn.Linear):
            (affects the initialization of weight_scale and bias_scale).
            Another option, if you want to do something like this, is
            to re-initialize the parameters.
-        Note: it uses the default initialization for the weight and bias,
-        inherited from nn.Linear.  For modules with small fan-in, this
-        may be larger than optimal.
+        initial_speed: this affects how fast the parameter will
+           learn near the start of training; you can set it to a
+           value less than one if you suspect that a module
+           is contributing to instability near the start of training.
+           Note: regardless of the use of this option, it's best to
+           use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster.  Must be greater than 0.
    """
    def __init__(self, *args,
                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
                 **kwargs):
        super(ScaledLinear, self).__init__(*args, **kwargs)
        initial_scale = torch.tensor(initial_scale).log()
@@ -150,10 +155,10 @@ class ScaledLinear(nn.Linear):
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in nn.Linear
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear

-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
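As a sanity check on the new initialization (an illustration, not part of the diff): a uniform distribution on (-a, a) has variance a^2/3, so a = sqrt(3)*std gives exactly the intended standard deviation std = 0.01/initial_speed, and the default initial_speed=1.0 reproduces the previous std of 0.01.

import torch
import torch.nn as nn

# Illustrative check: Var(U(-a, a)) = a^2 / 3, so a = sqrt(3) * std makes
# the uniform init below have standard deviation std = 0.01 / initial_speed.
for initial_speed in (1.0, 0.5):
    std = 0.01 / initial_speed
    a = (3 ** 0.5) * std
    w = torch.empty(1024, 1024)
    nn.init.uniform_(w, -a, a)
    print(f"initial_speed={initial_speed}: measured std={w.std():.4f}")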
@@ -176,8 +181,11 @@ class ScaledLinear(nn.Linear):
 class ScaledConv1d(nn.Conv1d):
     # See docs for ScaledLinear
     def __init__(self, *args,
-                 initial_scale=1.0, **kwargs):
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -185,10 +193,10 @@ class ScaledConv1d(nn.Conv1d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class

-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -218,7 +226,11 @@ class ScaledConv1d(nn.Conv1d):
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, initial_scale=1.0, **kwargs):
+    # See docs for ScaledLinear
+    def __init__(self, *args,
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -226,10 +238,10 @@ class ScaledConv2d(nn.Conv2d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class

-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -350,7 +362,11 @@ class DoubleSwish(torch.nn.Module):
 class ScaledEmbedding(nn.Module):
-    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
+    r"""This is a modified version of nn.Embedding that introduces a learnable scale
+    on the parameters.  Note: due to how we initialize it, it's best used with
+    schedulers like Noam that have a warmup period.
+
+    It is a simple lookup table that stores embeddings of a fixed dictionary and size.
     This module is often used to store word embeddings and retrieve them using indices.
     The input to the module is a list of indices, and the output is the corresponding
@@ -369,6 +385,15 @@ class ScaledEmbedding(nn.Module):
         sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                  See Notes for more details regarding sparse gradients.
+        initial_speed (float, optional): This affects how fast the parameter will
+           learn near the start of training; you can set it to a value less than
+           one if you suspect that a module is contributing to instability near
+           the start of training.  Note: regardless of the use of this option,
+           it's best to use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster.  Must be greater than 0.
+
     Attributes:
         weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                          initialized from :math:`\mathcal{N}(0, 1)`
@@ -416,6 +441,7 @@ class ScaledEmbedding(nn.Module):
                  [ 0.1535, -2.0309,  0.9315],
                  [ 0.0000,  0.0000,  0.0000],
                  [-0.1655,  0.9897,  0.0635]]])
     """
     __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
                      'scale_grad_by_freq', 'sparse']
@@ -429,7 +455,8 @@ class ScaledEmbedding(nn.Module):
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                  scale_grad_by_freq: bool = False,
-                 sparse: bool = False) -> None:
+                 sparse: bool = False,
+                 initial_speed: float = 1.0) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -446,12 +473,12 @@ class ScaledEmbedding(nn.Module):
         self.sparse = sparse

         self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
+        self.reset_parameters(initial_speed)

-    def reset_parameters(self) -> None:
-        std = 0.01
+    def reset_parameters(self, initial_speed: float = 1.0) -> None:
+        std = 0.01 / initial_speed
         nn.init.normal_(self.weight, std=std)
         nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
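A quick numerical check of the last hunk (an illustration, not code from this commit): the weight is drawn with std = 0.01/initial_speed while self.scale starts at log(1/std), so the product weight * exp(scale), which is the magnitude the rest of the network effectively sees, has standard deviation near 1.0 for any initial_speed; only the early learning dynamics change.

import torch

# Illustrative check: weight ~ N(0, std^2) with std = 0.01 / initial_speed,
# and scale starts at log(1/std), so weight * exp(scale) has stddev ~1.0
# regardless of the initial_speed chosen.
for initial_speed in (0.5, 1.0, 2.0):
    std = 0.01 / initial_speed
    weight = torch.randn(1000, 512) * std
    scale = torch.tensor(1.0 / std).log()
    print(f"initial_speed={initial_speed}: "
          f"effective std={(weight * scale.exp()).std():.4f}")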