From 56d9928934dfb7247aae1f0ba71c71747504fea6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 22 May 2022 11:56:59 +0800 Subject: [PATCH] Scale down modules at initialization --- .../pruned_transducer_stateless4/conformer.py | 45 ++++---- .../pruned_transducer_stateless4/decoder.py | 103 +++++++++++++++++- .../pruned_transducer_stateless4/joiner.py | 67 +++++++++++- .../ASR/pruned_transducer_stateless4/model.py | 5 +- .../pruned_transducer_stateless4/scaling.py | 18 +-- 5 files changed, 203 insertions(+), 35 deletions(-) mode change 120000 => 100644 egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py mode change 120000 => 100644 egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py index 33ca51743..8fd72cd33 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py @@ -170,22 +170,25 @@ class ConformerEncoderLayer(nn.Module): ) self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), + nn.Linear(d_model, dim_feedforward), ActivationBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ScaledLinear(dim_feedforward, d_model, + initial_scale=0.05), ) self.feed_forward_macaron = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), + nn.Linear(d_model, dim_feedforward), ActivationBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ScaledLinear(dim_feedforward, d_model, + initial_scale=0.05), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module = ConvolutionModule(d_model, + cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -435,13 +438,13 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) self.out_proj = ScaledLinear( - embed_dim, embed_dim, bias=True, initial_scale=0.25 + embed_dim, embed_dim, bias=True, initial_scale=0.05 ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) @@ -457,8 +460,8 @@ class RelPositionMultiheadAttention(nn.Module): return self.pos_bias_v * self.pos_bias_v_scale.exp() def _reset_parameters(self) -> None: - nn.init.uniform_(self.pos_bias_u, -0.2, 0.2) - nn.init.uniform_(self.pos_bias_v, -0.2, 0.2) + nn.init.uniform_(self.pos_bias_u, -0.05, 0.05) + nn.init.uniform_(self.pos_bias_v, -0.05, 0.05) def forward( self, @@ -518,11 +521,11 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb, self.embed_dim, self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), + self.in_proj.weight, + self.in_proj.bias, self.dropout, - self.out_proj.get_weight(), - self.out_proj.get_bias(), + self.out_proj.weight, + self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -852,7 +855,7 @@ class ConvolutionModule(nn.Module): # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = ScaledConv1d( + self.pointwise_conv1 = nn.Conv1d( channels, 2 * channels, kernel_size=1, @@ -878,7 +881,7 @@ class ConvolutionModule(nn.Module): channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) - self.depthwise_conv = ScaledConv1d( + self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, @@ -901,7 +904,7 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, - initial_scale=0.25, + initial_scale=0.05, ) def forward(self, x: Tensor) -> Tensor: @@ -969,7 +972,7 @@ class Conv2dSubsampling(nn.Module): super().__init__() self.conv = nn.Sequential( - ScaledConv2d( + nn.Conv2d( in_channels=1, out_channels=layer1_channels, kernel_size=3, @@ -977,7 +980,7 @@ class Conv2dSubsampling(nn.Module): ), ActivationBalancer(channel_dim=1), DoubleSwish(), - ScaledConv2d( + nn.Conv2d( in_channels=layer1_channels, out_channels=layer2_channels, kernel_size=3, @@ -985,7 +988,7 @@ class Conv2dSubsampling(nn.Module): ), ActivationBalancer(channel_dim=1), DoubleSwish(), - ScaledConv2d( + nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, @@ -994,7 +997,7 @@ class Conv2dSubsampling(nn.Module): ActivationBalancer(channel_dim=1), DoubleSwish(), ) - self.out = ScaledLinear( + self.out = nn.Linear( layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels ) # set learn_eps=False because out_norm is preceded by `out`, and `out` diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py deleted file mode 120000 index 0793c5709..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py new file mode 100644 index 000000000..a1c755d73 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py @@ -0,0 +1,102 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + padding_idx=blank_id, + ) + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim, + bias=False, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + embedding_out = self.embedding(y) + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + return embedding_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py deleted file mode 120000 index 815fd4bb6..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py new file mode 100644 index 000000000..afcd690e9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py @@ -0,0 +1,66 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + super().__init__() + + self.encoder_proj = nn.Linear(encoder_dim, joiner_dim) + self.decoder_proj = nn.Linear(decoder_dim, joiner_dim) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] + + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/model.py b/egs/librispeech/ASR/pruned_transducer_stateless4/model.py index 79ec24d16..24898ed09 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/model.py @@ -19,7 +19,6 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface -from scaling import ScaledLinear from icefall.utils import add_sos @@ -63,10 +62,10 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( + self.simple_am_proj = nn.Linear( encoder_dim, vocab_size, ) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) def forward( self, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py index f0a1ec0ca..7ba71a94a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py @@ -179,13 +179,9 @@ class ScaledLinear(nn.Linear): with torch.no_grad(): self.weight[:] *= initial_scale if self.bias is not None: - torch.nn.init.uniform_(self.bias, -0.2, 0.2) - - def get_weight(self): # not needed any more but kept for back compatibility - return self.weight - - def get_bias(self): - return self.bias + torch.nn.init.uniform_(self.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) @@ -201,7 +197,9 @@ class ScaledConv1d(nn.Conv1d): with torch.no_grad(): self.weight[:] *= initial_scale if self.bias is not None: - torch.nn.init.uniform_(self.bias, -0.2, 0.2) + torch.nn.init.uniform_(self.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) def get_weight(self): # TODO: delete return self.weight @@ -222,7 +220,9 @@ class ScaledConv2d(nn.Conv2d): with torch.no_grad(): self.weight[:] *= initial_scale if self.bias is not None: - torch.nn.init.uniform_(self.bias, -0.2, 0.2) + torch.nn.init.uniform_(self.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) def get_weight(self): return self.weight