Rework of initialization

This commit is contained in:
Daniel Povey 2022-03-16 12:42:59 +08:00
parent 1331199530
commit 633213424d
4 changed files with 78 additions and 75 deletions

View File

@ -62,13 +62,6 @@ class Conv2dSubsampling(nn.Module):
self.out_norm = BasicNorm(odim, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
self._reset_parameters()
def _reset_parameters(self):
# init weights with smaller than default variance, because otherwise
# they learn too slowly in relative terms (assuming we're training with adam).
nn.init.normal_(self.conv[0].weight, std=0.05)
nn.init.constant_(self.conv[0].bias, 0.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -406,8 +399,36 @@ class BasicNorm(torch.nn.Module):
return x * scales
class ScaledLinear(nn.Linear):
def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
"""
A modified version of nn.Linear where the parameters are scaled before
use, via:
weight = self.weight * (self.weight_scale * self.scale_speed).exp()
bias = self.bias * (self.bias_scale * self.scale_speed).exp()
Args:
Accepts the standard args and kwargs that nn.Linear accepts
e.g. in_features, out_features, bias=False.
scale_speed: a factor that affects how fast the weight_scale
and bias_scale learn; this value is suitable for Adam-type
optimizers.
initial_scale: you can override this if you want to increase
or decrease the initial magnitude of the module's output
(affects the initialization of weight_scale and bias_scale).
Another option, if you want to do something like this, is
to re-initialize the parameters.
Note: it uses the default initialization for the weight and bias,
inherited from nn.Linear. For modules with small fan-in, this
may be larger than optimal.
"""
def __init__(self, *args,
scale_speed: float = 5.0,
initial_scale: float = 1.0,
**kwargs):
super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@ -417,6 +438,17 @@ class ScaledLinear(nn.Linear):
else:
self.register_parameter('bias_scale', None)
self._reset_parameters() # Overrides the reset_parameters in nn.Linear
def _reset_parameters(self):
nn.init.normal_(self.weight, std=0.05)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1]
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
@ -425,7 +457,6 @@ class ScaledLinear(nn.Linear):
return (None if self.bias is None else
self.bias * (self.bias_scale * self.scale_speed).exp())
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(input, self.get_weight(),
self.get_bias())
@ -442,6 +473,17 @@ class ScaledConv1d(nn.Conv1d):
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters() # Overrides the reset_parameters in base class
def _reset_parameters(self):
nn.init.normal_(self.weight, std=0.05)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
def get_weight(self):
return self.weight * (self.weight_scale * self.scale_speed).exp()
@ -471,6 +513,16 @@ class ScaledConv2d(nn.Conv2d):
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
else:
self.register_parameter('bias_scale', None)
self._reset_parameters() # Overrides the reset_parameters in base class
def _reset_parameters(self):
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
nn.init.normal_(self.weight, std=0.05)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
def get_weight(self):

View File

@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
DerivBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
ScaledLinear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
DerivBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
ScaledLinear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module):
), "embed_dim must be divisible by num_heads"
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@ -434,7 +434,6 @@ class RelPositionMultiheadAttention(nn.Module):
self.scale_speed = scale_speed
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
@ -444,12 +443,8 @@ class RelPositionMultiheadAttention(nn.Module):
return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
nn.init.normal_(self.pos_bias_u, std=0.05)
nn.init.normal_(self.pos_bias_v, std=0.05)
def forward(
self,
@ -891,7 +886,6 @@ class ConvolutionModule(nn.Module):
stride=1,
padding=0,
bias=bias,
initial_scale=0.25
)
def forward(self, x: Tensor) -> Tensor:

View File

@ -183,7 +183,7 @@ class ScaledEmbedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None,
sparse: bool = False,
scale_speed: float = 5.0) -> None:
super(ScaledEmbedding, self).__init__()
self.num_embeddings = num_embeddings
@ -198,19 +198,18 @@ class ScaledEmbedding(nn.Module):
self.scale_grad_by_freq = scale_grad_by_freq
self.scale_speed = scale_speed
self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed)
if _weight is None:
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
'Shape of weight does not match num_embeddings and embedding_dim'
self.weight = nn.Parameter(_weight)
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
self.sparse = sparse
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=self.embedding_dim**-0.5)
nn.init.normal_(self.weight, std=0.05)
nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
@ -228,7 +227,6 @@ class ScaledEmbedding(nn.Module):
None, 2.0, # None, 2.0 relates to normalization
self.scale_grad_by_freq, self.sparse)
def extra_repr(self) -> str:
s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
if self.padding_idx is not None:
@ -238,45 +236,3 @@ class ScaledEmbedding(nn.Module):
if self.sparse is not False:
s += ', sparse=True'
return s.format(**self.__dict__)
@classmethod
def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
max_norm=None, norm_type=2., scale_grad_by_freq=False,
sparse=False):
r"""Creates Embedding instance from given 2-dimensional FloatTensor.
Args:
embeddings (Tensor): FloatTensor containing weights for the Embedding.
First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
padding_idx (int, optional): See module initialization documentation.
max_norm (float, optional): See module initialization documentation.
norm_type (float, optional): See module initialization documentation. Default ``2``.
scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
sparse (bool, optional): See module initialization documentation.
Examples::
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000, 5.1000, 6.3000]])
"""
assert embeddings.dim() == 2, \
'Embeddings parameter is expected to be 2-dimensional'
rows, cols = embeddings.shape
embedding = cls(
num_embeddings=rows,
embedding_dim=cols,
_weight=embeddings,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
embedding.weight.requires_grad = not freeze
return embedding

View File

@ -110,7 +110,8 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean",
# was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization..
default="transducer_stateless/randcombine1_expscale3_rework2d"
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved