From 633213424d24de73c09170d68b138ef830ed3cbd Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 16 Mar 2022 12:42:59 +0800
Subject: [PATCH] Rework of initialization

---
 .../ASR/conformer_ctc/subsampling.py        | 70 ++++++++++++++++---
 .../ASR/transducer_stateless/conformer.py   | 16 ++---
 .../ASR/transducer_stateless/decoder.py     | 64 +++--------------
 .../ASR/transducer_stateless/train.py       |  3 +-
 4 files changed, 78 insertions(+), 75 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 50a9db41a..5e44c5b29 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -62,13 +62,6 @@ class Conv2dSubsampling(nn.Module):
         self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain median of output to be close to zero.
         self.out_balancer = DerivBalancer(channel_dim=-1,
                                           min_positive=0.45,
                                           max_positive=0.55)
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        # init weights with smaller than default variance, because otherwise
-        # they learn too slowly in relative terms (assuming we're training with adam).
-        nn.init.normal_(self.conv[0].weight, std=0.05)
-        nn.init.constant_(self.conv[0].bias, 0.0)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
 
@@ -406,8 +399,36 @@ class BasicNorm(torch.nn.Module):
         return x * scales
 
 
+
+
 class ScaledLinear(nn.Linear):
-    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
+    """
+    A modified version of nn.Linear where the parameters are scaled before
+    use, via:
+         weight = self.weight * (self.weight_scale * self.scale_speed).exp()
+         bias = self.bias * (self.bias_scale * self.scale_speed).exp()
+
+    Args:
+        Accepts the standard args and kwargs that nn.Linear accepts,
+        e.g. in_features, out_features, bias=False.
+
+        scale_speed: a factor that affects how fast the weight_scale
+            and bias_scale learn; the default is suitable for Adam-type
+            optimizers.
+        initial_scale: you can override this if you want to increase
+            or decrease the initial magnitude of the module's output
+            (affects the initialization of weight_scale and bias_scale).
+            Another option, if you want to do something like this, is
+            to re-initialize the parameters.
+
+    Note: _reset_parameters() below overrides the initialization inherited
+    from nn.Linear, so that the effective (scaled) weight starts out with a
+    standard deviation of initial_scale / sqrt(fan_in) and the bias starts
+    out at zero.
+    """
+    def __init__(self, *args,
+                 scale_speed: float = 5.0,
+                 initial_scale: float = 1.0,
+                 **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
         initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
         self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -417,6 +438,17 @@ class ScaledLinear(nn.Linear):
         else:
             self.register_parameter('bias_scale', None)
 
+        self._reset_parameters()  # Overrides the reset_parameters in nn.Linear
+
+    def _reset_parameters(self):
+        nn.init.normal_(self.weight, std=0.05)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.0)
+        fan_in = self.weight.shape[1]
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        with torch.no_grad():
+            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+
     def get_weight(self):
         return self.weight * (self.weight_scale * self.scale_speed).exp()
 
@@ -425,7 +457,6 @@ class ScaledLinear(nn.Linear):
         return (None if self.bias is None else
                 self.bias * (self.bias_scale * self.scale_speed).exp())
 
-
     def forward(self, input: Tensor) -> Tensor:
         return torch.nn.functional.linear(input, self.get_weight(),
                                           self.get_bias())
@@ -442,6 +473,17 @@ class ScaledConv1d(nn.Conv1d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
+        self._reset_parameters()  # Overrides the reset_parameters in base class
+
+    def _reset_parameters(self):
+        nn.init.normal_(self.weight, std=0.05)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.0)
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        with torch.no_grad():
+            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
+
     def get_weight(self):
         return self.weight * (self.weight_scale * self.scale_speed).exp()
 
@@ -471,6 +513,16 @@ class ScaledConv2d(nn.Conv2d):
             self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
+        self._reset_parameters()  # Overrides the reset_parameters in base class
+
+    def _reset_parameters(self):
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
+        nn.init.normal_(self.weight, std=0.05)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.0)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        with torch.no_grad():
+            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
 
     def get_weight(self):
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index cc1ae53a1..0b89fdcd2 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
             DerivBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model),
         )
 
         self.feed_forward_macaron = nn.Sequential(
@@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
             DerivBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model),
        )
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
 
         self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)
 
         # linear transformation for positional encoding.
         self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -434,7 +434,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.scale_speed = scale_speed
         self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
         self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
-
         self._reset_parameters()
 
     def _pos_bias_u(self):
@@ -444,12 +443,8 @@ class RelPositionMultiheadAttention(nn.Module):
         return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
 
     def _reset_parameters(self) -> None:
-        nn.init.xavier_uniform_(self.in_proj.weight)
-        nn.init.constant_(self.in_proj.bias, 0.0)
-        nn.init.constant_(self.out_proj.bias, 0.0)
-
-        nn.init.xavier_uniform_(self.pos_bias_u)
-        nn.init.xavier_uniform_(self.pos_bias_v)
+        nn.init.normal_(self.pos_bias_u, std=0.05)
+        nn.init.normal_(self.pos_bias_v, std=0.05)
 
     def forward(
         self,
@@ -891,7 +886,6 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25
         )
 
     def forward(self, x: Tensor) -> Tensor:
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index bc4bcb3f6..838b6794d 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -183,7 +183,7 @@ class ScaledEmbedding(nn.Module):
     def __init__(self, num_embeddings: int, embedding_dim: int,
                  padding_idx: Optional[int] = None,
                  scale_grad_by_freq: bool = False,
-                 sparse: bool = False, _weight: Optional[Tensor] = None,
+                 sparse: bool = False,
                  scale_speed: float = 5.0) -> None:
         super(ScaledEmbedding, self).__init__()
         self.num_embeddings = num_embeddings
@@ -198,19 +198,18 @@ class ScaledEmbedding(nn.Module):
         self.scale_grad_by_freq = scale_grad_by_freq
         self.scale_speed = scale_speed
-        self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed)
-
-        if _weight is None:
-            self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-            self.reset_parameters()
-        else:
-            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
-                'Shape of weight does not match num_embeddings and embedding_dim'
-            self.weight = nn.Parameter(_weight)
+        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
 
         self.sparse = sparse
 
+        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
+        self.reset_parameters()
+
+
     def reset_parameters(self) -> None:
-        nn.init.normal_(self.weight, std=self.embedding_dim**-0.5)
+        nn.init.normal_(self.weight, std=0.05)
+        nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
+
         if self.padding_idx is not None:
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
@@ -228,7 +227,6 @@ class ScaledEmbedding(nn.Module):
                                  None, 2.0,  # None, 2.0 relates to normalization
                                  self.scale_grad_by_freq, self.sparse)
 
-
     def extra_repr(self) -> str:
         s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
         if self.padding_idx is not None:
@@ -238,45 +236,3 @@ class ScaledEmbedding(nn.Module):
         if self.sparse is not False:
             s += ', sparse=True'
         return s.format(**self.__dict__)
-
-    @classmethod
-    def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
-                        max_norm=None, norm_type=2., scale_grad_by_freq=False,
-                        sparse=False):
-        r"""Creates Embedding instance from given 2-dimensional FloatTensor.
-
-        Args:
-            embeddings (Tensor): FloatTensor containing weights for the Embedding.
-                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
-            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
-                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
-            padding_idx (int, optional): See module initialization documentation.
-            max_norm (float, optional): See module initialization documentation.
-            norm_type (float, optional): See module initialization documentation. Default ``2``.
-            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
-            sparse (bool, optional): See module initialization documentation.
-
-        Examples::
-
-            >>> # FloatTensor containing pretrained weights
-            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
-            >>> embedding = nn.Embedding.from_pretrained(weight)
-            >>> # Get embeddings for index 1
-            >>> input = torch.LongTensor([1])
-            >>> embedding(input)
-            tensor([[ 4.0000,  5.1000,  6.3000]])
-        """
-        assert embeddings.dim() == 2, \
-            'Embeddings parameter is expected to be 2-dimensional'
-        rows, cols = embeddings.shape
-        embedding = cls(
-            num_embeddings=rows,
-            embedding_dim=cols,
-            _weight=embeddings,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse)
-        embedding.weight.requires_grad = not freeze
-        return embedding
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 41fdb4ef3..8f2157715 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,8 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean",
+        # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization..
+        default="transducer_stateless/randcombine1_expscale3_rework2d",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
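
To illustrate the reparameterization this patch standardizes, here is a minimal, self-contained sketch. ToyScaledLinear is a hypothetical name, not the icefall class itself, and the sketch assumes only PyTorch. The stored weight is kept at a small standard deviation (0.05, following the rationale in the removed Conv2dSubsampling comment: smaller-than-default variance so the weights learn faster in relative terms under Adam), and _reset_parameters offsets the learned log-scale so that the effective weight returned by get_weight() starts at roughly 1/sqrt(fan_in).

import torch
import torch.nn as nn


class ToyScaledLinear(nn.Linear):
    """Hypothetical sketch of the scaled-linear idea used in this patch."""

    def __init__(self, *args, scale_speed: float = 5.0,
                 initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale_speed = scale_speed
        # weight_scale and bias_scale are learned log-scales; scale_speed
        # multiplies them before exponentiation so they adapt quickly.
        initial_scale = torch.tensor(initial_scale).log() / scale_speed
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter("bias_scale", None)
        self._reset_parameters()

    def _reset_parameters(self):
        # Store a deliberately small weight (std=0.05) and fold the usual
        # 1/sqrt(fan_in) magnitude into the log-scale instead.
        nn.init.normal_(self.weight, std=0.05)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
        fan_in = self.weight.shape[1]
        target_std = fan_in ** -0.5
        with torch.no_grad():
            self.weight_scale += torch.tensor(target_std / 0.05).log() / self.scale_speed

    def get_weight(self):
        return self.weight * (self.weight_scale * self.scale_speed).exp()

    def get_bias(self):
        if self.bias is None:
            return None
        return self.bias * (self.bias_scale * self.scale_speed).exp()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.get_weight(), self.get_bias())


if __name__ == "__main__":
    m = ToyScaledLinear(256, 512)
    print(m.weight.std().item())         # stored weight: ~0.05
    print(m.get_weight().std().item())   # effective weight: ~1/sqrt(256) = 0.0625
    print(m(torch.randn(8, 256)).shape)  # torch.Size([8, 512])

With initial_scale left at 1.0 this reproduces the effective starting magnitude of a conventional 1/sqrt(fan_in) initialization, while the tensor that Adam actually updates stays small, which is the trade-off the reworked _reset_parameters methods in subsampling.py aim for.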