Reduce initial scaling of modules

Daniel Povey 2022-03-12 16:53:03 +08:00
parent d906bc2a4f
commit a392cb9fbc
3 changed files with 15 additions and 11 deletions


@@ -407,12 +407,13 @@ class BasicNorm(torch.nn.Module):
 class ScaledLinear(nn.Linear):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         self.scale_speed = scale_speed
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
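
The new parameterization stores log(initial_scale) / scale_speed, so a module whose forward applies exp(scale_speed * weight_scale) to its weight (as this initialization implies) starts out with an effective scale of exactly initial_scale; the old torch.zeros(()) initialization corresponds to the new default initial_scale=1.0. A standalone sketch checking that arithmetic (not recipe code):

    import torch

    scale_speed = 5.0
    initial_scale = 0.25  # the value used for the feed-forward output layer below

    # The parameter stores log(initial_scale) / scale_speed ...
    weight_scale = torch.tensor(initial_scale).log() / scale_speed

    # ... so exp(scale_speed * weight_scale) recovers initial_scale at init.
    effective_scale = (weight_scale * scale_speed).exp()
    assert torch.isclose(effective_scale, torch.tensor(initial_scale))

    # The old code used torch.zeros(()), i.e. exp(0) = 1.0, which the new
    # default initial_scale=1.0 reproduces: log(1.0) / scale_speed == 0.0.
    assert (torch.tensor(1.0).log() / scale_speed).item() == 0.0
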
@@ -431,12 +432,14 @@ class ScaledLinear(nn.Linear):
 class ScaledConv1d(nn.Conv1d):
-    def __init__(self, *args, scale_speed = 5.0, **kwargs):
+    def __init__(self, *args, scale_speed = 5.0,
+                 initial_scale=1.0, **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
@@ -459,12 +462,13 @@ class ScaledConv1d(nn.Conv1d):
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
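
The same change is repeated verbatim in ScaledLinear, ScaledConv1d, and ScaledConv2d. A hypothetical helper (not part of this commit) capturing the shared pattern could look like:

    import torch
    import torch.nn as nn

    def init_scale_params(module: nn.Module, scale_speed: float,
                          initial_scale: float) -> None:
        # Hypothetical refactoring of the repeated pattern above: store
        # log(initial_scale) / scale_speed in both scale parameters so the
        # effective scale exp(scale_speed * scale) equals initial_scale.
        s = torch.tensor(initial_scale).log() / scale_speed
        module.scale_speed = scale_speed
        module.weight_scale = nn.Parameter(s.clone().detach())
        if module.bias is not None:
            module.bias_scale = nn.Parameter(s.clone().detach())
        else:
            module.register_parameter("bias_scale", None)
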


@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
                                max_factor=0.01),
             SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
         self.feed_forward_macaron = nn.Sequential(
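
Setting initial_scale=0.25 on the feed-forward block's output projection makes that branch contribute at a quarter of its former magnitude when training starts, so the surrounding residual connection begins closer to the identity (presumably the point of "Reduce initial scaling of modules"). A toy illustration of the effect, using plain nn.Linear layers rather than the recipe's scaled modules:

    import torch
    import torch.nn as nn

    # Scaling the last linear layer's effective weight and bias by 0.25
    # shrinks the whole branch's output by 0.25, since the branch output is
    # linear in that layer's parameters.
    d_model, dim_feedforward = 8, 32
    branch = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                           nn.ReLU(),
                           nn.Linear(dim_feedforward, d_model))
    x = torch.randn(4, d_model)
    y_full = branch(x)
    with torch.no_grad():
        branch[-1].weight *= 0.25
        branch[-1].bias *= 0.25
    y_scaled = branch(x)
    assert torch.allclose(y_scaled, 0.25 * y_full, atol=1e-6)
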


@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2",
+        default="transducer_stateless/randcombine1_expscale3_rework2b",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved