Mirror of https://github.com/k2-fsa/icefall.git

Scale down modules at initialization

Commit 56d9928934 (parent 5d57dd3930)
@@ -170,22 +170,25 @@ class ConformerEncoderLayer(nn.Module):
         )

         self.feed_forward = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
+            nn.Linear(d_model, dim_feedforward),
             ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model,
+                         initial_scale=0.05),
         )

         self.feed_forward_macaron = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
+            nn.Linear(d_model, dim_feedforward),
             ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model,
+                         initial_scale=0.05),
         )

-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(d_model,
+                                             cnn_module_kernel)

         self.norm_final = BasicNorm(d_model)

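Why the final module of each branch gets initial_scale=0.05: every Conformer sub-module feeds a residual connection, so shrinking only the last projection at init shrinks the whole branch output by the same factor, and each block starts close to the identity map. A toy illustration of that effect (a sketch, not icefall code; residual_branch and its sizes are made up):

import torch
import torch.nn as nn

def residual_branch(d: int, s: float) -> nn.Sequential:
    # Shrink only the final projection, as ScaledLinear(..., initial_scale=s) does.
    proj = nn.Linear(d, d)
    with torch.no_grad():
        proj.weight *= s
        proj.bias *= s
    return nn.Sequential(nn.Linear(d, d), nn.ReLU(), proj)

x = torch.randn(32, 256)
for s in (0.25, 0.05):
    y = x + residual_branch(256, s)(x)
    # The relative deviation from the identity scales roughly with s.
    print(s, ((y - x).norm() / x.norm()).item())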
@@ -435,13 +438,13 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"

-        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
         self.out_proj = ScaledLinear(
-            embed_dim, embed_dim, bias=True, initial_scale=0.25
+            embed_dim, embed_dim, bias=True, initial_scale=0.05
         )

         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
@@ -457,8 +460,8 @@ class RelPositionMultiheadAttention(nn.Module):
         return self.pos_bias_v * self.pos_bias_v_scale.exp()

     def _reset_parameters(self) -> None:
-        nn.init.uniform_(self.pos_bias_u, -0.2, 0.2)
-        nn.init.uniform_(self.pos_bias_v, -0.2, 0.2)
+        nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
+        nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)

     def forward(
         self,
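The relative-position biases pos_bias_u and pos_bias_v likewise start four times smaller, uniform in ±0.05 instead of ±0.2. A quick sanity check of the scale difference on toy tensors (shapes are arbitrary):

import torch

u_old = torch.empty(8, 64).uniform_(-0.2, 0.2)
u_new = torch.empty(8, 64).uniform_(-0.05, 0.05)
# Mean |value| of Uniform(-a, a) is a/2: about 0.1 vs. 0.025 here.
print(u_old.abs().mean().item(), u_new.abs().mean().item())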
@@ -518,11 +521,11 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.get_weight(),
-            self.in_proj.get_bias(),
+            self.in_proj.weight,
+            self.in_proj.bias,
             self.dropout,
-            self.out_proj.get_weight(),
-            self.out_proj.get_bias(),
+            self.out_proj.weight,
+            self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
@@ -852,7 +855,7 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0

-        self.pointwise_conv1 = ScaledConv1d(
+        self.pointwise_conv1 = nn.Conv1d(
             channels,
             2 * channels,
             kernel_size=1,
@@ -878,7 +881,7 @@ class ConvolutionModule(nn.Module):
             channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
         )

-        self.depthwise_conv = ScaledConv1d(
+        self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
@@ -901,7 +904,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25,
+            initial_scale=0.05,
         )

     def forward(self, x: Tensor) -> Tensor:
@@ -969,7 +972,7 @@ class Conv2dSubsampling(nn.Module):
         super().__init__()

         self.conv = nn.Sequential(
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
@@ -977,7 +980,7 @@ class Conv2dSubsampling(nn.Module):
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=layer1_channels,
                 out_channels=layer2_channels,
                 kernel_size=3,
@@ -985,7 +988,7 @@ class Conv2dSubsampling(nn.Module):
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=layer2_channels,
                 out_channels=layer3_channels,
                 kernel_size=3,
@@ -994,7 +997,7 @@ class Conv2dSubsampling(nn.Module):
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
-        self.out = ScaledLinear(
+        self.out = nn.Linear(
             layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
         )
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
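The in-features of self.out come from the frequency axis of the subsampled input. Assuming, as in the stateless2 recipe this file is based on, that the first conv preserves the size and the next two use kernel 3, stride 2, padding 0, each halving maps n to (n - 3) // 2 + 1 = (n - 1) // 2. A quick check of the (((in_channels - 1) // 2 - 1) // 2) expression with a hypothetical 80-dim fbank input:

def halve(n: int) -> int:
    # One conv with kernel 3, stride 2, padding 0.
    return (n - 3) // 2 + 1

in_channels = 80
assert halve(halve(in_channels)) == ((in_channels - 1) // 2 - 1) // 2 == 19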
egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py (symlink deleted)
@@ -1 +0,0 @@
-../pruned_transducer_stateless2/decoder.py

egs/librispeech/ASR/pruned_transducer_stateless4/decoder.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Decoder(nn.Module):
+    """This class modifies the stateless decoder from the following paper:
+
+        RNN-transducer with stateless prediction network
+        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+
+    It removes the recurrent connection from the decoder, i.e., the prediction
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
+
+    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        decoder_dim: int,
+        blank_id: int,
+        context_size: int,
+    ):
+        """
+        Args:
+          vocab_size:
+            Number of tokens of the modeling unit including blank.
+          decoder_dim:
+            Dimension of the input embedding, and of the decoder output.
+          blank_id:
+            The ID of the blank symbol.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
+        """
+        super().__init__()
+
+        self.embedding = nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=decoder_dim,
+            padding_idx=blank_id,
+        )
+        self.blank_id = blank_id
+
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        self.vocab_size = vocab_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=decoder_dim,
+                out_channels=decoder_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=decoder_dim,
+                bias=False,
+            )
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, U).
+          need_pad:
+            True to left pad the input. Should be True during training.
+            False to not pad the input. Should be False during inference.
+        Returns:
+          Return a tensor of shape (N, U, decoder_dim).
+        """
+        y = y.to(torch.int64)
+        embedding_out = self.embedding(y)
+        if self.context_size > 1:
+            embedding_out = embedding_out.permute(0, 2, 1)
+            if need_pad is True:
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
+            else:
+                # During inference time, there is no need to do extra padding
+                # as we only need one output
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        embedding_out = F.relu(embedding_out)
+        return embedding_out
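A shape sketch of how this decoder is used (hypothetical sizes; assumes the Decoder class above is in scope):

import torch

decoder = Decoder(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)

# Training: the input is left-padded by context_size - 1, so the output keeps length U.
y = torch.randint(0, 500, (8, 20))           # (N, U)
print(decoder(y, need_pad=True).shape)       # torch.Size([8, 20, 512])

# Inference: feed exactly context_size tokens and get a single output frame.
y_ctx = torch.randint(0, 500, (8, 2))        # (N, context_size)
print(decoder(y_ctx, need_pad=False).shape)  # torch.Size([8, 1, 512])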
egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py (symlink deleted)
@@ -1 +0,0 @@
-../pruned_transducer_stateless2/joiner.py

egs/librispeech/ASR/pruned_transducer_stateless4/joiner.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+
+class Joiner(nn.Module):
+    def __init__(
+        self,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
+        super().__init__()
+
+        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
+        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
+        self.output_linear = nn.Linear(joiner_dim, vocab_size)
+
+    def forward(
+        self,
+        encoder_out: torch.Tensor,
+        decoder_out: torch.Tensor,
+        project_input: bool = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+          encoder_out:
+            Output from the encoder. Its shape is (N, T, s_range, C).
+          decoder_out:
+            Output from the decoder. Its shape is (N, T, s_range, C).
+          project_input:
+            If true, apply input projections encoder_proj and decoder_proj.
+            If this is false, it is the user's responsibility to do this
+            manually.
+        Returns:
+          Return a tensor of shape (N, T, s_range, C).
+        """
+        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
+
+        if project_input:
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
+                decoder_out
+            )
+        else:
+            logit = encoder_out + decoder_out
+
+        logit = self.output_linear(torch.tanh(logit))
+
+        return logit
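A shape sketch for the joiner (hypothetical sizes; assumes the Joiner class above is in scope; s_range is the pruning width used by the pruned RNN-T loss):

import torch

joiner = Joiner(encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500)
enc = torch.randn(8, 100, 5, 512)  # (N, T, s_range, encoder_dim)
dec = torch.randn(8, 100, 5, 512)  # (N, T, s_range, decoder_dim)
print(joiner(enc, dec).shape)      # torch.Size([8, 100, 5, 500])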
@@ -19,7 +19,6 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from scaling import ScaledLinear

 from icefall.utils import add_sos

@@ -63,10 +62,10 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

-        self.simple_am_proj = ScaledLinear(
+        self.simple_am_proj = nn.Linear(
             encoder_dim, vocab_size,
         )
-        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+        self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)

     def forward(
         self,
@@ -179,13 +179,9 @@ class ScaledLinear(nn.Linear):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
-
-    def get_weight(self):  # not needed any more but kept for back compatibility
-        return self.weight
-
-    def get_bias(self):
-        return self.bias
+                torch.nn.init.uniform_(self.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)

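Net effect of the new init: both the weight and the bias of a ScaledLinear start initial_scale times smaller than a default nn.Linear, with the bias now drawn from ±0.1 * initial_scale rather than a fixed ±0.2. A minimal stand-in reproducing just the init logic of the hunk above (illustration only, not icefall's class):

import torch
import torch.nn as nn

class MiniScaledLinear(nn.Linear):
    def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        with torch.no_grad():
            self.weight[:] *= initial_scale
            if self.bias is not None:
                torch.nn.init.uniform_(
                    self.bias, -0.1 * initial_scale, 0.1 * initial_scale
                )

for s in (1.0, 0.25, 0.05):
    layer = MiniScaledLinear(512, 512, initial_scale=s)
    print(s, layer.weight.abs().mean().item())  # shrinks in proportion to s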
@@ -201,7 +197,9 @@ class ScaledConv1d(nn.Conv1d):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
+                torch.nn.init.uniform_(self.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)

     def get_weight(self):  # TODO: delete
         return self.weight
@@ -222,7 +220,9 @@ class ScaledConv2d(nn.Conv2d):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
+                torch.nn.init.uniform_(self.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)

     def get_weight(self):
         return self.weight