diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index a8475c21e..0d3b0aa02 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -967,16 +967,24 @@ class Conv2dSubsampling(nn.Module):
         """
         assert in_channels >= 7
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values.  The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.conv = nn.Sequential(
             ScaledConv2d(
                 in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
                 in_channels=layer1_channels, out_channels=layer2_channels,
-                kernel_size=3, stride=2
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
index 13e45e03b..3470b647f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
@@ -55,10 +55,17 @@ class Decoder(nn.Module):
             1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
+
+        # This initial_speed is to slightly slow down the relative speed of
+        # training during the warmup phase by increasing the magnitude of the
+        # initial parameter values.  The intention is to allow us to
+        # use a higher lr_factor.
+        initial_speed = 0.5
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
+            initial_speed=initial_speed
         )
         self.blank_id = blank_id
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index f2423492f..4c45205ce 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -134,13 +134,18 @@ class ScaledLinear(nn.Linear):
            (affects the initialization of weight_scale and bias_scale).
            Another option, if you want to do something like this, is
            to re-initialize the parameters.
-
-    Note: it uses the default initialization for the weight and bias,
-    inherited from nn.Linear.  For modules with small fan-in, this
-    may be larger than optimal.
+        initial_speed: this affects how fast the parameter will
+           learn near the start of training; you can set it to a
+           value less than one if you suspect that a module
+           is contributing to instability near the start of training.
+           Note: regardless of the use of this option, it's best to
+           use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster.  Must be greater than 0.
""" def __init__(self, *args, initial_scale: float = 1.0, + initial_speed: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -150,10 +155,10 @@ class ScaledLinear(nn.Linear): else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in nn.Linear + self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -176,8 +181,11 @@ class ScaledLinear(nn.Linear): class ScaledConv1d(nn.Conv1d): + # See docs for ScaledLinear def __init__(self, *args, - initial_scale=1.0, **kwargs): + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -185,10 +193,10 @@ class ScaledConv1d(nn.Conv1d): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class + self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -218,7 +226,11 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, initial_scale=1.0, **kwargs): + # See docs for ScaledLinear + def __init__(self, *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -226,10 +238,10 @@ class ScaledConv2d(nn.Conv2d): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class + self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -350,7 +362,11 @@ class DoubleSwish(torch.nn.Module): class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + r"""This is a modified version of nn.Embedding that introduces a learnable scale + on the parameters. Note: due to how we initialize it, it's best used with + schedulers like Noam that have a warmup period. + + It is a simple lookup table that stores embeddings of a fixed dictionary and size. This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding @@ -369,6 +385,15 @@ class ScaledEmbedding(nn.Module): sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. 
+        initial_speed (float, optional):  This affects how fast the parameter will
+           learn near the start of training; you can set it to a value less than
+           one if you suspect that a module is contributing to instability near
+           the start of training.  Note: regardless of the use of this option,
+           it's best to use schedulers like Noam that have a warm-up period.
+           Alternatively you can set it to more than 1 if you want it to
+           initially train faster.  Must be greater than 0.
+
+

     Attributes:
         weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
             initialized from :math:`\mathcal{N}(0, 1)`
@@ -416,6 +441,7 @@ class ScaledEmbedding(nn.Module):
                  [ 0.1535, -2.0309,  0.9315],
                  [ 0.0000,  0.0000,  0.0000],
                  [-0.1655,  0.9897,  0.0635]]])
+
     """
     __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
                      'scale_grad_by_freq', 'sparse']
@@ -429,7 +455,8 @@ class ScaledEmbedding(nn.Module):
     def __init__(self, num_embeddings: int, embedding_dim: int,
                  padding_idx: Optional[int] = None,
                  scale_grad_by_freq: bool = False,
-                 sparse: bool = False) -> None:
+                 sparse: bool = False,
+                 initial_speed: float = 1.0) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -446,12 +473,12 @@ class ScaledEmbedding(nn.Module):
         self.sparse = sparse

         self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
+        self.reset_parameters(initial_speed)

-    def reset_parameters(self) -> None:
-        std = 0.01
+    def reset_parameters(self, initial_speed: float = 1.0) -> None:
+        std = 0.01 / initial_speed
         nn.init.normal_(self.weight, std=std)
         nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
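
Not part of the diff above: the following is a minimal sketch of what initial_speed does to the initialization logic introduced in scaling.py. The raw parameter std becomes 0.01 / initial_speed while the learnable log-scale is reduced to compensate, so the module's initial output statistics are unchanged; what changes is how large a typical optimizer step is relative to the parameter magnitude, i.e. the effective learning speed during the warm-up phase. The helper name scaled_linear_init is hypothetical and only mirrors the _reset_parameters logic of ScaledLinear; it is not part of the PR.

import torch

def scaled_linear_init(out_features: int, in_features: int, initial_speed: float = 1.0):
    # Mirrors ScaledLinear._reset_parameters in the diff above (hypothetical helper).
    std = 0.01 / initial_speed
    a = (3 ** 0.5) * std
    # Raw weight: uniform with standard deviation `std`.
    weight = torch.empty(out_features, in_features).uniform_(-a, a)
    fan_in = in_features
    # Log-scale chosen so that weight * exp(weight_scale) has std ~= fan_in ** -0.5,
    # independent of initial_speed.
    weight_scale = torch.tensor(fan_in ** -0.5 / std).log()
    return weight, weight_scale

for speed in (1.0, 0.5):
    w, s = scaled_linear_init(512, 512, initial_speed=speed)
    effective = w * s.exp()
    print(f"initial_speed={speed}: raw std={w.std().item():.4f}, "
          f"effective std={effective.std().item():.4f}")
# Expected: the raw std doubles when initial_speed=0.5, while the effective std
# stays around 1/sqrt(512) ~= 0.044 in both cases; the larger raw parameters then
# change relatively more slowly early in training, which is the intended effect.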