Scale down modules at initialization

Daniel Povey 2022-05-22 11:56:59 +08:00
parent 5d57dd3930
commit 56d9928934
5 changed files with 203 additions and 35 deletions
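
In short, the commit makes each sub-module start out with a much smaller output: input projections become plain nn.Linear layers, and the final projection of each sub-module keeps ScaledLinear / ScaledConv1d but with initial_scale reduced from 0.25 to 0.05. A minimal sketch of the effect (not part of the commit; it assumes the ScaledLinear from this recipe's scaling.py is importable, and the tensor sizes are made up):

    import torch
    import torch.nn as nn

    from scaling import ScaledLinear  # from this recipe's scaling.py

    x = torch.randn(10, 512)
    plain = nn.Linear(512, 512)
    small = ScaledLinear(512, 512, initial_scale=0.05)

    # The scaled layer's weights are multiplied by 0.05 at construction time,
    # so its output is roughly 20x smaller than a default nn.Linear at init.
    print(plain(x).abs().mean())
    print(small(x).abs().mean())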


@@ -170,22 +170,25 @@ class ConformerEncoderLayer(nn.Module):
         )
         self.feed_forward = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
+            nn.Linear(d_model, dim_feedforward),
             ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model,
+                         initial_scale=0.05),
         )
         self.feed_forward_macaron = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
+            nn.Linear(d_model, dim_feedforward),
             ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model,
+                         initial_scale=0.05),
         )
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(d_model,
+                                             cnn_module_kernel)
         self.norm_final = BasicNorm(d_model)
@@ -435,13 +438,13 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"
-        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
         self.out_proj = ScaledLinear(
-            embed_dim, embed_dim, bias=True, initial_scale=0.25
+            embed_dim, embed_dim, bias=True, initial_scale=0.05
         )
         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
@@ -457,8 +460,8 @@ class RelPositionMultiheadAttention(nn.Module):
         return self.pos_bias_v * self.pos_bias_v_scale.exp()
     def _reset_parameters(self) -> None:
-        nn.init.uniform_(self.pos_bias_u, -0.2, 0.2)
-        nn.init.uniform_(self.pos_bias_v, -0.2, 0.2)
+        nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
+        nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
     def forward(
         self,
@@ -518,11 +521,11 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.get_weight(),
-            self.in_proj.get_bias(),
+            self.in_proj.weight,
+            self.in_proj.bias,
             self.dropout,
-            self.out_proj.get_weight(),
-            self.out_proj.get_bias(),
+            self.out_proj.weight,
+            self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
@@ -852,7 +855,7 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
-        self.pointwise_conv1 = ScaledConv1d(
+        self.pointwise_conv1 = nn.Conv1d(
             channels,
             2 * channels,
             kernel_size=1,
@@ -878,7 +881,7 @@ class ConvolutionModule(nn.Module):
             channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
         )
-        self.depthwise_conv = ScaledConv1d(
+        self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
@@ -901,7 +904,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25,
+            initial_scale=0.05,
         )
     def forward(self, x: Tensor) -> Tensor:
@@ -969,7 +972,7 @@ class Conv2dSubsampling(nn.Module):
         super().__init__()
         self.conv = nn.Sequential(
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
@@ -977,7 +980,7 @@ class Conv2dSubsampling(nn.Module):
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=layer1_channels,
                 out_channels=layer2_channels,
                 kernel_size=3,
@@ -985,7 +988,7 @@ class Conv2dSubsampling(nn.Module):
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=layer2_channels,
                 out_channels=layer3_channels,
                 kernel_size=3,
@@ -994,7 +997,7 @@ class Conv2dSubsampling(nn.Module):
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
-        self.out = ScaledLinear(
+        self.out = nn.Linear(
             layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
         )
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
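
A quick check of the arithmetic in the self.out line above (the numbers below are made up, not part of the commit): the expression halves the feature axis twice, so the flattened input size of self.out is layer3_channels times the twice-halved feature dimension.

    # Quick arithmetic check (made-up values) of the input size of self.out.
    in_channels = 80                           # e.g. 80-dim fbank features
    layer3_channels = 128
    freq = ((in_channels - 1) // 2 - 1) // 2   # 19 after two halvings
    print(layer3_channels * freq)              # 2432 = in_features of self.out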


@@ -1 +0,0 @@
-../pruned_transducer_stateless2/decoder.py


@@ -0,0 +1,102 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the prediction
    network. Different from the above paper, it adds an extra Conv1d
    right after the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int,
        blank_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          decoder_dim:
            Dimension of the input embedding, and of the decoder output.
          blank_id:
            The ID of the blank symbol.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        """
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=decoder_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        self.vocab_size = vocab_size
        if context_size > 1:
            self.conv = nn.Conv1d(
                in_channels=decoder_dim,
                out_channels=decoder_dim,
                kernel_size=context_size,
                padding=0,
                groups=decoder_dim,
                bias=False,
            )

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U).
          need_pad:
            True to left pad the input. Should be True during training.
            False to not pad the input. Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, decoder_dim).
        """
        y = y.to(torch.int64)
        embedding_out = self.embedding(y)
        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                embedding_out = F.pad(
                    embedding_out, pad=(self.context_size - 1, 0)
                )
            else:
                # During inference time, there is no need to do extra padding
                # as we only need one output
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)
        embedding_out = F.relu(embedding_out)
        return embedding_out
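
To make the need_pad logic above concrete, here is a small usage sketch (not part of the commit; all sizes are made up): during training the full label sequence is passed and left-padded internally, while during inference exactly context_size previous tokens are passed unpadded and a single output frame comes out.

    import torch

    decoder = Decoder(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)

    # Training: full label sequence (N, U); forward() left-pads by context_size - 1.
    y = torch.randint(low=1, high=500, size=(8, 25))
    print(decoder(y, need_pad=True).shape)       # torch.Size([8, 25, 512])

    # Inference: exactly context_size previous tokens, no padding, one output frame.
    y_ctx = torch.randint(low=1, high=500, size=(8, 2))
    print(decoder(y_ctx, need_pad=False).shape)  # torch.Size([8, 1, 512])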


@@ -1 +0,0 @@
-../pruned_transducer_stateless2/joiner.py


@@ -0,0 +1,66 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class Joiner(nn.Module):
    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
          project_input:
            If true, apply input projections encoder_proj and decoder_proj.
            If this is false, it is the user's responsibility to do this
            manually.
        Returns:
          Return a tensor of shape (N, T, s_range, C).
        """
        assert encoder_out.ndim == decoder_out.ndim == 4
        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
                decoder_out
            )
        else:
            logit = encoder_out + decoder_out

        logit = self.output_linear(torch.tanh(logit))

        return logit
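
A usage sketch for the joiner above (not part of the commit; the dimensions are made up), matching the (N, T, s_range, C) layout used by the pruned transducer loss:

    import torch

    joiner = Joiner(encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500)

    encoder_out = torch.randn(8, 100, 5, 512)  # (N, T, s_range, encoder_dim)
    decoder_out = torch.randn(8, 100, 5, 512)  # (N, T, s_range, decoder_dim)

    logit = joiner(encoder_out, decoder_out, project_input=True)
    print(logit.shape)  # torch.Size([8, 100, 5, 500]); last dim is vocab_size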


@@ -19,7 +19,6 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from scaling import ScaledLinear
 from icefall.utils import add_sos
@@ -63,10 +62,10 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
-        self.simple_am_proj = ScaledLinear(
+        self.simple_am_proj = nn.Linear(
             encoder_dim, vocab_size,
         )
-        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+        self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
     def forward(
         self,


@@ -179,13 +179,9 @@ class ScaledLinear(nn.Linear):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
-    def get_weight(self):  # not needed any more but kept for back compatibility
-        return self.weight
-    def get_bias(self):
-        return self.bias
+                torch.nn.init.uniform_(self.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)
@@ -201,7 +197,9 @@ class ScaledConv1d(nn.Conv1d):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
+                torch.nn.init.uniform_(self.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)
     def get_weight(self):  # TODO: delete
         return self.weight
@@ -222,7 +220,9 @@ class ScaledConv2d(nn.Conv2d):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
+                torch.nn.init.uniform_(self.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)
     def get_weight(self):
         return self.weight
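
For reference, a standalone sketch of the initialization pattern the three classes above converge on (the class name below is hypothetical, not the repo's code): both the weight and the bias shrink with initial_scale, so initial_scale=0.05 gives biases in roughly [-0.005, 0.005] instead of the old [-0.2, 0.2].

    import torch
    import torch.nn as nn

    class TinyScaledLinear(nn.Linear):  # hypothetical, for illustration only
        def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
            super().__init__(*args, **kwargs)
            with torch.no_grad():
                # Scale down the default initialization of weight and bias together.
                self.weight[:] *= initial_scale
                if self.bias is not None:
                    nn.init.uniform_(self.bias,
                                     -0.1 * initial_scale,
                                     0.1 * initial_scale)

    layer = TinyScaledLinear(256, 256, initial_scale=0.05)
    print(layer.weight.abs().mean())  # ~0.05x the default nn.Linear init
    print(layer.bias.abs().max())     # <= 0.005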