Repository: https://github.com/k2-fsa/icefall.git (mirror)

Commit 37ab0bcfa5: Reduce speed of some components
Parent commit: 7c46c3b0d4
@@ -967,16 +967,24 @@ class Conv2dSubsampling(nn.Module):
         """
         assert in_channels >= 7
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values. The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.conv = nn.Sequential(
             ScaledConv2d(
                 in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
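A rough way to see what the added comment means: with an Adam-style optimizer the size of a parameter update is roughly the learning rate, independent of the parameter's current magnitude, so starting the parameters off larger makes each step a smaller relative change. The sketch below is only illustrative; the 0.01 base std and the division by initial_speed come from the _reset_parameters changes later in this commit, while the fixed per-step update size is an assumption, not icefall code.

# Illustrative sketch (not icefall code): relative per-step change of a
# freshly initialized parameter for two initial_speed values, assuming an
# Adam-like optimizer whose per-step update size is roughly the learning rate.
base_std = 0.01      # base init std used by the Scaled* modules below
step_size = 1e-3     # hypothetical effective update size during warmup

for initial_speed in (1.0, 0.5):
    std = base_std / initial_speed     # matches `std = 0.01 / initial_speed`
    relative_change = step_size / std  # fraction of the parameter moved per step
    print(f"initial_speed={initial_speed}: init std={std:.3f}, "
          f"relative change per step ~{relative_change:.0%}")

# initial_speed=0.5 doubles the initial std, so the same step is about half
# as large relative to the parameter: the module adapts more slowly during
# warmup, which is what allows a higher lr_factor overall.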
@@ -55,10 +55,17 @@ class Decoder(nn.Module):
             1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values. The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
+            initial_speed=initial_speed
         )
         self.blank_id = blank_id
 
@@ -134,13 +134,18 @@ class ScaledLinear(nn.Linear):
            (affects the initialization of weight_scale and bias_scale).
            Another option, if you want to do something like this, is
            to re-initialize the parameters.
-
-    Note: it uses the default initialization for the weight and bias,
-    inherited from nn.Linear.  For modules with small fan-in, this
-    may be larger than optimal.
+        initial_speed: this affects how fast the parameter will
+           learn near the start of training; you can set it to a
+           value less than one if you suspect that a module
+           is contributing to instability near the start of training.
+           Note: regardless of the use of this option, it's best to
+           use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster. Must be greater than 0.
     """
     def __init__(self, *args,
                  initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
                  **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
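For illustration, a hypothetical use of the new keyword argument (the feature dimensions here are invented; only initial_scale and initial_speed come from this diff):

# Hypothetical usage; dimensions are made up for illustration.
proj = ScaledLinear(
    256, 512,
    initial_scale=0.25,  # scales the module's initial output magnitude
    initial_speed=0.5,   # raw weights start 2x larger => slower relative learning early on
)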
@@ -150,10 +155,10 @@ class ScaledLinear(nn.Linear):
         else:
             self.register_parameter('bias_scale', None)
 
-        self._reset_parameters()  # Overrides the reset_parameters in nn.Linear
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in nn.Linear
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
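The division by initial_speed only changes how the raw weight tensor is drawn at initialization; the sketch below isolates just that math (it is not the full ScaledLinear, which also keeps the learnable weight_scale and bias_scale parameters seen elsewhere in this diff).

import torch
import torch.nn as nn

def init_weight(out_features: int, in_features: int,
                initial_speed: float = 1.0) -> torch.Tensor:
    # Same formulas as the patched _reset_parameters: a uniform init whose
    # std is 0.01 / initial_speed.  A uniform distribution on [-a, a] has
    # std a / sqrt(3), so a = sqrt(3) * std gives exactly that std.
    std = 0.01 / initial_speed
    a = (3 ** 0.5) * std
    weight = torch.empty(out_features, in_features)
    nn.init.uniform_(weight, -a, a)
    return weight

w_default = init_weight(512, 512, initial_speed=1.0)  # std ~ 0.010
w_slow = init_weight(512, 512, initial_speed=0.5)     # std ~ 0.020, i.e. 2x larger
print(w_default.std().item(), w_slow.std().item())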
@@ -176,8 +181,11 @@ class ScaledLinear(nn.Linear):
 
 
 class ScaledConv1d(nn.Conv1d):
+    # See docs for ScaledLinear
     def __init__(self, *args,
-                 initial_scale=1.0, **kwargs):
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -185,10 +193,10 @@ class ScaledConv1d(nn.Conv1d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -218,7 +226,11 @@ class ScaledConv1d(nn.Conv1d):
 
 
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, initial_scale=1.0, **kwargs):
+    # See docs for ScaledLinear
+    def __init__(self, *args,
+                 initial_scale: float = 1.0,
+                 initial_speed: float = 1.0,
+                 **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -226,10 +238,10 @@ class ScaledConv2d(nn.Conv2d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
-        self._reset_parameters()  # Overrides the reset_parameters in base class
+        self._reset_parameters(initial_speed)  # Overrides the reset_parameters in base class
 
-    def _reset_parameters(self):
-        std = 0.01
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.01 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -350,7 +362,11 @@ class DoubleSwish(torch.nn.Module):
 
 
 class ScaledEmbedding(nn.Module):
-    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
+    r"""This is a modified version of nn.Embedding that introduces a learnable scale
+    on the parameters. Note: due to how we initialize it, it's best used with
+    schedulers like Noam that have a warmup period.
+
+    It is a simple lookup table that stores embeddings of a fixed dictionary and size.
 
     This module is often used to store word embeddings and retrieve them using indices.
     The input to the module is a list of indices, and the output is the corresponding
@@ -369,6 +385,15 @@ class ScaledEmbedding(nn.Module):
         sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                  See Notes for more details regarding sparse gradients.
+
+        initial_speed (float, optional): This affects how fast the parameter will
+           learn near the start of training; you can set it to a value less than
+           one if you suspect that a module is contributing to instability near
+           the start of training. Note: regardless of the use of this option,
+           it's best to use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster. Must be greater than 0.
+
 
     Attributes:
         weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                          initialized from :math:`\mathcal{N}(0, 1)`
@@ -416,6 +441,7 @@ class ScaledEmbedding(nn.Module):
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])
+
     """
     __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
                      'scale_grad_by_freq', 'sparse']
@@ -429,7 +455,8 @@ class ScaledEmbedding(nn.Module):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                  scale_grad_by_freq: bool = False,
-                 sparse: bool = False) -> None:
+                 sparse: bool = False,
+                 initial_speed: float = 1.0) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -446,12 +473,12 @@ class ScaledEmbedding(nn.Module):
         self.sparse = sparse
 
         self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
+        self.reset_parameters(initial_speed)
 
 
 
-    def reset_parameters(self) -> None:
-        std = 0.01
+    def reset_parameters(self, initial_speed: float = 1.0) -> None:
+        std = 0.01 / initial_speed
         nn.init.normal_(self.weight, std=std)
         nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
 
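For ScaledEmbedding the whole picture is visible in this last hunk: the raw weights get std 0.01 / initial_speed while the log-scale is set to log(1/std), so (assuming the module multiplies its output by scale.exp(), as the log(1/std) initialization suggests) the effective embedding starts with roughly unit std for any initial_speed; only the split between the raw parameter and the multiplier changes. A standalone numerical check of that invariance, using just the two formulas from this hunk:

import torch

def init_embedding(num_embeddings: int, embedding_dim: int,
                   initial_speed: float = 1.0):
    # Mirrors the patched reset_parameters of ScaledEmbedding.
    std = 0.01 / initial_speed
    weight = torch.empty(num_embeddings, embedding_dim)
    torch.nn.init.normal_(weight, std=std)
    scale = torch.tensor(1.0 / std).log()  # the learnable log-scale in the real module
    return weight, scale

for speed in (1.0, 0.5):
    weight, scale = init_embedding(1000, 512, initial_speed=speed)
    effective = weight * scale.exp()  # what the output is effectively built from
    print(f"initial_speed={speed}: raw std={weight.std().item():.4f}, "
          f"effective std={effective.std().item():.3f}")

# Both settings give an effective std of about 1.0 at initialization; with
# initial_speed=0.5 the raw weights are twice as large and exp(scale) is half
# as large, so an optimizer update of a given size moves the effective
# embedding about half as fast early in training.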