Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Commit 37ab0bcfa5 (parent 7c46c3b0d4): Reduce speed of some components
@@ -967,16 +967,24 @@ class Conv2dSubsampling(nn.Module):
         """
         assert in_channels >= 7
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values. The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.conv = nn.Sequential(
             ScaledConv2d(
                 in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
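For intuition, here is a small self-contained sketch (illustrative only, not icefall code; the 0.01 base value comes from the _reset_parameters hunks further down) of what passing initial_speed=0.5 does to the raw initialization of a scaled layer: the standard deviation becomes 0.01 / 0.5 = 0.02, i.e. twice the default magnitude.

import torch

def init_weight(shape, initial_speed: float = 1.0) -> torch.Tensor:
    # Mirrors the uniform init used in _reset_parameters below:
    # std = 0.01 / initial_speed, drawn uniformly in [-sqrt(3)*std, sqrt(3)*std].
    std = 0.01 / initial_speed
    a = (3 ** 0.5) * std
    return torch.empty(shape).uniform_(-a, a)

w_default = init_weight((256, 256), initial_speed=1.0)
w_slow = init_weight((256, 256), initial_speed=0.5)
print(w_default.std().item())  # ~0.01
print(w_slow.std().item())     # ~0.02: larger raw parameter magnitude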
@@ -55,10 +55,17 @@ class Decoder(nn.Module):
             1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values. The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
+            initial_speed=initial_speed
         )
         self.blank_id = blank_id
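The comment's reasoning leans on how Adam-style optimizers behave: the per-step update magnitude is roughly independent of the parameter's scale, so doubling the initial magnitude roughly halves the relative change per step early in training. A toy illustration (illustrative numbers and a made-up loss, not icefall code):

import torch

def relative_change_after_one_step(initial_std: float, lr: float = 1e-3) -> float:
    # One Adam step on a dummy loss. Adam's update is ~lr per element regardless
    # of the parameter's scale, so the *relative* change shrinks as the initial
    # magnitude grows.
    torch.manual_seed(0)
    p = torch.nn.Parameter(torch.randn(1000) * initial_std)
    opt = torch.optim.Adam([p], lr=lr)
    before = p.detach().clone()
    (p * torch.randn(1000)).sum().backward()
    opt.step()
    return ((p.detach() - before).norm() / before.norm()).item()

print(relative_change_after_one_step(0.01))  # baseline relative change
print(relative_change_after_one_step(0.02))  # roughly half: "slower" during warmup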
@@ -134,13 +134,18 @@ class ScaledLinear(nn.Linear):
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
-
-    Note: it uses the default initialization for the weight and bias,
-    inherited from nn.Linear.  For modules with small fan-in, this
-    may be larger than optimal.
+        initial_speed: this affects how fast the parameter will
+           learn near the start of training; you can set it to a
+           value less than one if you suspect that a module
+           is contributing to instability near the start of training.
+           Note: regardless of the use of this option, it's best to
+           use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster. Must be greater than 0.
     """
     def __init__(self, *args,
                  initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
                  **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
@@ -150,10 +155,10 @@ class ScaledLinear(nn.Linear):
         else:
             self.register_parameter('bias_scale', None)
 
-        self._reset_parameters()  # Overrides the reset_parameters in nn.Linear
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
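Note that the new argument defaults to 1.0, so call sites that do not pass initial_speed keep the previous std = 0.01 initialization; only modules that opt in (Conv2dSubsampling and Decoder above pass 0.5) change behaviour. A quick standalone sanity check of that equivalence, reusing the formulas from this hunk rather than importing icefall's scaling.py:

import torch

def new_init(shape, initial_speed: float = 1.0) -> torch.Tensor:
    std = 0.01 / initial_speed      # new formula from _reset_parameters
    a = (3 ** 0.5) * std
    return torch.empty(shape).uniform_(-a, a)

def old_init(shape) -> torch.Tensor:
    std = 0.01                      # previous hard-coded value
    a = (3 ** 0.5) * std
    return torch.empty(shape).uniform_(-a, a)

torch.manual_seed(0)
w_new = new_init((512, 512))        # default initial_speed=1.0
torch.manual_seed(0)
w_old = old_init((512, 512))
assert torch.equal(w_new, w_old)    # identical values: the change is opt-in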
@@ -176,8 +181,11 @@ class ScaledLinear(nn.Linear):
 
 
 class ScaledConv1d(nn.Conv1d):
+    # See docs for ScaledLinear
     def __init__(self, *args,
-                 initial_scale=1.0, **kwargs):
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -185,10 +193,10 @@ class ScaledConv1d(nn.Conv1d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -218,7 +226,11 @@ class ScaledConv1d(nn.Conv1d):
 
 
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, initial_scale=1.0, **kwargs):
+    # See docs for ScaledLinear
+    def __init__(self, *args,
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -226,10 +238,10 @@ class ScaledConv2d(nn.Conv2d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -350,7 +362,11 @@ class DoubleSwish(torch.nn.Module):
 
 
 class ScaledEmbedding(nn.Module):
-    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
+    r"""This is a modified version of nn.Embedding that introduces a learnable scale
+    on the parameters. Note: due to how we initialize it, it's best used with
+    schedulers like Noam that have a warmup period.
+
+    It is a simple lookup table that stores embeddings of a fixed dictionary and size.
 
     This module is often used to store word embeddings and retrieve them using indices.
     The input to the module is a list of indices, and the output is the corresponding
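As a rough sketch of what the "learnable scale on the parameters" means in practice (the forward pass itself is not part of this diff, so the lookup-and-multiply below is an assumption based on the docstring): the raw weights can be kept small while a single learnable log-scale restores the effective magnitude.

import torch
import torch.nn.functional as F

num_embeddings, embedding_dim = 500, 512
weight = torch.nn.Parameter(torch.randn(num_embeddings, embedding_dim) * 0.02)
scale = torch.nn.Parameter(torch.tensor(1.0 / 0.02).log())  # learnable log-scale

tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
embedded = F.embedding(tokens, weight) * scale.exp()
print(embedded.shape)         # torch.Size([2, 3, 512])
print(embedded.std().item())  # ~1.0: small raw weights, compensating scale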
@@ -369,6 +385,15 @@ class ScaledEmbedding(nn.Module):
         sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                  See Notes for more details regarding sparse gradients.
 
+        initial_speed (float, optional): This affects how fast the parameter will
+            learn near the start of training; you can set it to a value less than
+            one if you suspect that a module is contributing to instability near
+            the start of training. Note: regardless of the use of this option,
+            it's best to use schedulers like Noam that have a warm-up period.
+            Alternatively you can set it to more than 1 if you want it to
+            initially train faster. Must be greater than 0.
+
+
     Attributes:
         weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
             initialized from :math:`\mathcal{N}(0, 1)`
@@ -416,6 +441,7 @@ class ScaledEmbedding(nn.Module):
                  [ 0.1535, -2.0309,  0.9315],
                  [ 0.0000,  0.0000,  0.0000],
                  [-0.1655,  0.9897,  0.0635]]])
+
     """
     __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
                      'scale_grad_by_freq', 'sparse']
@@ -429,7 +455,8 @@ class ScaledEmbedding(nn.Module):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                  scale_grad_by_freq: bool = False,
-                 sparse: bool = False) -> None:
+                 sparse: bool = False,
+                 initial_speed: float = 1.0) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -446,12 +473,12 @@ class ScaledEmbedding(nn.Module):
         self.sparse = sparse
 
         self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
+        self.reset_parameters(initial_speed)
 
 
-    def reset_parameters(self) -> None:
-        std = 0.01
+    def reset_parameters(self, initial_speed: float = 1.0) -> None:
+        std = 0.01 / initial_speed
         nn.init.normal_(self.weight, std=std)
         nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
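The two changed lines above work together: the raw weight std is divided by initial_speed while the log-scale is initialized to log(1/std), so the effective embedding (weight * scale.exp()) starts with roughly unit std regardless of initial_speed. Only the raw parameterization changes, which is what slows the relative learning rate during warmup. A standalone sketch using exactly these two formulas:

import torch

def embedding_init(num_embeddings: int, dim: int, initial_speed: float = 1.0):
    # The same two formulas as reset_parameters above.
    std = 0.01 / initial_speed
    weight = torch.empty(num_embeddings, dim).normal_(std=std)
    scale = torch.tensor(1.0 / std).log()
    return weight, scale

for speed in (1.0, 0.5):
    weight, scale = embedding_init(1000, 512, initial_speed=speed)
    effective = weight * scale.exp()
    print(f"initial_speed={speed}: raw std={weight.std().item():.4f}, "
          f"effective std={effective.std().item():.4f}")
# The raw std doubles at initial_speed=0.5, but the effective std stays ~1.0:
# the module's initial output is unchanged; only the optimizer's relative step
# size on the raw weights is reduced.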