Rework of initialization

Daniel Povey 2022-03-16 12:42:59 +08:00
parent 1331199530
commit 633213424d
4 changed files with 78 additions and 75 deletions

View File

@@ -62,13 +62,6 @@ class Conv2dSubsampling(nn.Module):
        self.out_norm = BasicNorm(odim, learn_eps=False)
        # constrain median of output to be close to zero.
        self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        # init weights with smaller than default variance, because otherwise
-        # they learn too slowly in relative terms (assuming we're training with adam).
-        nn.init.normal_(self.conv[0].weight, std=0.05)
-        nn.init.constant_(self.conv[0].bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -406,8 +399,36 @@ class BasicNorm(torch.nn.Module):
        return x * scales

class ScaledLinear(nn.Linear):
-    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
+    """
+    A modified version of nn.Linear where the parameters are scaled before
+    use, via:
+         weight = self.weight * (self.weight_scale * self.scale_speed).exp()
+         bias = self.bias * (self.bias_scale * self.scale_speed).exp()
+
+    Args:
+        Accepts the standard args and kwargs that nn.Linear accepts
+        e.g. in_features, out_features, bias=False.
+
+        scale_speed: a factor that affects how fast the weight_scale
+           and bias_scale learn; this value is suitable for Adam-type
+           optimizers.
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects the initialization of weight_scale and bias_scale).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+
+    Note: it uses the default initialization for the weight and bias,
+    inherited from nn.Linear. For modules with small fan-in, this
+    may be larger than optimal.
+    """
+    def __init__(self, *args,
+                 scale_speed: float = 5.0,
+                 initial_scale: float = 1.0,
+                 **kwargs):
        super(ScaledLinear, self).__init__(*args, **kwargs)
        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
@@ -417,6 +438,17 @@ class ScaledLinear(nn.Linear):
        else:
            self.register_parameter('bias_scale', None)
+        self._reset_parameters()  # Overrides the reset_parameters in nn.Linear
+
+    def _reset_parameters(self):
+        nn.init.normal_(self.weight, std=0.05)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.0)
+        fan_in = self.weight.shape[1]
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        with torch.no_grad():
+            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)

    def get_weight(self):
        return self.weight * (self.weight_scale * self.scale_speed).exp()
@@ -425,7 +457,6 @@ class ScaledLinear(nn.Linear):
        return (None if self.bias is None else
                self.bias * (self.bias_scale * self.scale_speed).exp())

    def forward(self, input: Tensor) -> Tensor:
        return torch.nn.functional.linear(input, self.get_weight(),
                                          self.get_bias())
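
Taken together with the __init__ above, the new ScaledLinear._reset_parameters keeps the raw weight at std 0.05 and folds the usual 1/sqrt(fan_in) factor into weight_scale, so the effective weight returned by get_weight() starts near the standard magnitude. A minimal standalone sketch of that arithmetic (mine, not part of the diff; the local variable names and the 512/256 dimensions are illustrative):

import torch

scale_speed = 5.0
fan_in = 512
raw_weight = torch.randn(256, fan_in) * 0.05        # like nn.init.normal_(..., std=0.05)
weight_scale = torch.zeros(())                       # initial_scale=1.0 -> log(1)/scale_speed == 0
target_std = fan_in ** -0.5                          # 1/sqrt(fan_in)
weight_scale += torch.tensor(target_std / 0.05).log() / scale_speed

effective = raw_weight * (weight_scale * scale_speed).exp()    # what get_weight() computes
print(round(effective.std().item(), 3), round(target_std, 3))  # both roughly 0.044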
@@ -442,6 +473,17 @@ class ScaledConv1d(nn.Conv1d):
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)
+        self._reset_parameters()  # Overrides the reset_parameters in base class
+
+    def _reset_parameters(self):
+        nn.init.normal_(self.weight, std=0.05)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.0)
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        with torch.no_grad():
+            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)

    def get_weight(self):
        return self.weight * (self.weight_scale * self.scale_speed).exp()
@@ -471,6 +513,16 @@ class ScaledConv2d(nn.Conv2d):
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)
+        self._reset_parameters()  # Overrides the reset_parameters in base class
+
+    def _reset_parameters(self):
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
+        nn.init.normal_(self.weight, std=0.05)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.0)
+        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        with torch.no_grad():
+            self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)

    def get_weight(self):
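
For the conv variants the only difference from ScaledLinear is the fan_in expression: with a weight of shape (out_channels, in_channels, *kernel_size), weight.shape[1] * weight[0][0].numel() is in_channels times the number of kernel elements. A quick sketch (not from the diff) confirming that:

import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 5))
w = conv.weight                         # shape (16, 3, 3, 5)
fan_in = w.shape[1] * w[0][0].numel()   # 3 * (3 * 5) = 45
print(fan_in, fan_in ** -0.5)           # 45 0.149...

Applied to a Conv1d weight of shape (out_channels, in_channels, kernel_size), the same expression gives in_channels * kernel_size.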

View File

@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
            DerivBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model),
        )

        self.feed_forward_macaron = nn.Sequential(
@@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
            DerivBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            ScaledLinear(dim_feedforward, d_model),
        )

        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module):
        ), "embed_dim must be divisible by num_heads"

        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -434,7 +434,6 @@ class RelPositionMultiheadAttention(nn.Module):
        self.scale_speed = scale_speed
        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
        self._reset_parameters()

    def _pos_bias_u(self):
@@ -444,12 +443,8 @@ class RelPositionMultiheadAttention(nn.Module):
        return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()

    def _reset_parameters(self) -> None:
-        nn.init.xavier_uniform_(self.in_proj.weight)
-        nn.init.constant_(self.in_proj.bias, 0.0)
-        nn.init.constant_(self.out_proj.bias, 0.0)
-        nn.init.xavier_uniform_(self.pos_bias_u)
-        nn.init.xavier_uniform_(self.pos_bias_v)
+        nn.init.normal_(self.pos_bias_u, std=0.05)
+        nn.init.normal_(self.pos_bias_v, std=0.05)

    def forward(
        self,
@@ -891,7 +886,6 @@ class ConvolutionModule(nn.Module):
            stride=1,
            padding=0,
            bias=bias,
-            initial_scale=0.25
        )

    def forward(self, x: Tensor) -> Tensor:
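
The conformer.py changes drop initial_scale=0.25 from the feed-forward, attention and convolution output projections. From the ScaledLinear.__init__ shown earlier, initial_scale only sets the starting value of weight_scale (and bias_scale) to log(initial_scale) / scale_speed, i.e. it multiplies the module's initial effective weights by exactly initial_scale. A small sketch of that mapping (mine, not part of the commit):

import torch

scale_speed = 5.0
for initial_scale in (1.0, 0.25):
    weight_scale = torch.tensor(initial_scale).log() / scale_speed
    factor = (weight_scale * scale_speed).exp().item()
    print(initial_scale, round(factor, 3))
# 1.0 1.0
# 0.25 0.25

So with the argument removed, these output projections start at the full scale set by the new _reset_parameters instead of a quarter of it.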

View File

@@ -183,7 +183,7 @@ class ScaledEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False, _weight: Optional[Tensor] = None,
+                 sparse: bool = False,
                 scale_speed: float = 5.0) -> None:
        super(ScaledEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
@@ -198,19 +198,18 @@ class ScaledEmbedding(nn.Module):
        self.scale_grad_by_freq = scale_grad_by_freq
        self.scale_speed = scale_speed
-        self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed)
-        if _weight is None:
-            self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-            self.reset_parameters()
-        else:
-            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
-                'Shape of weight does not match num_embeddings and embedding_dim'
-            self.weight = nn.Parameter(_weight)
+        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
        self.sparse = sparse
+        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
+        self.reset_parameters()

    def reset_parameters(self) -> None:
-        nn.init.normal_(self.weight, std=self.embedding_dim**-0.5)
+        nn.init.normal_(self.weight, std=0.05)
+        nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)
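
The ScaledEmbedding change keeps the intended effective magnitude while moving to the std-0.05 convention used by the other scaled modules: previously the raw weight had std 1/sqrt(embedding_dim) and scale contributed a factor of sqrt(embedding_dim); now the raw weight has std 0.05 and scale contributes a factor of 1/0.05 = 20. A sketch of the two initializations (mine, not part of the diff; it assumes the forward pass multiplies the embedding output by (scale * scale_speed).exp(), analogous to get_weight() in the other scaled modules):

import torch

scale_speed, embedding_dim = 5.0, 512

# old: weight std = 1/sqrt(embedding_dim), scale = log(embedding_dim**0.5) / scale_speed
old_std = embedding_dim ** -0.5
old_scale = torch.tensor(embedding_dim ** 0.5).log() / scale_speed
print(old_std * (old_scale * scale_speed).exp().item())   # ~1.0

# new: weight std = 0.05, scale = log(1/0.05) / scale_speed
new_scale = torch.tensor(1.0 / 0.05).log() / scale_speed
print(0.05 * (new_scale * scale_speed).exp().item())      # ~1.0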
@@ -228,7 +227,6 @@ class ScaledEmbedding(nn.Module):
            None, 2.0,  # None, 2.0 relates to normalization
            self.scale_grad_by_freq, self.sparse)

    def extra_repr(self) -> str:
        s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
        if self.padding_idx is not None:
@@ -238,45 +236,3 @@ class ScaledEmbedding(nn.Module):
        if self.sparse is not False:
            s += ', sparse=True'
        return s.format(**self.__dict__)
-
-    @classmethod
-    def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
-                        max_norm=None, norm_type=2., scale_grad_by_freq=False,
-                        sparse=False):
-        r"""Creates Embedding instance from given 2-dimensional FloatTensor.
-        Args:
-            embeddings (Tensor): FloatTensor containing weights for the Embedding.
-                First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
-            freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
-                Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
-            padding_idx (int, optional): See module initialization documentation.
-            max_norm (float, optional): See module initialization documentation.
-            norm_type (float, optional): See module initialization documentation. Default ``2``.
-            scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
-            sparse (bool, optional): See module initialization documentation.
-        Examples::
-            >>> # FloatTensor containing pretrained weights
-            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
-            >>> embedding = nn.Embedding.from_pretrained(weight)
-            >>> # Get embeddings for index 1
-            >>> input = torch.LongTensor([1])
-            >>> embedding(input)
-            tensor([[ 4.0000, 5.1000, 6.3000]])
-        """
-        assert embeddings.dim() == 2, \
-            'Embeddings parameter is expected to be 2-dimensional'
-        rows, cols = embeddings.shape
-        embedding = cls(
-            num_embeddings=rows,
-            embedding_dim=cols,
-            _weight=embeddings,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse)
-        embedding.weight.requires_grad = not freeze
-        return embedding

View File

@@ -110,7 +110,8 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean",
+        # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization..
+        default="transducer_stateless/randcombine1_expscale3_rework2d",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved