Reduce initial scaling of modules

Daniel Povey 2022-03-12 16:53:03 +08:00
parent d906bc2a4f
commit a392cb9fbc
3 changed files with 15 additions and 11 deletions


@@ -407,12 +407,13 @@ class BasicNorm(torch.nn.Module):
 class ScaledLinear(nn.Linear):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         self.scale_speed = scale_speed
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
@@ -431,12 +432,14 @@ class ScaledLinear(nn.Linear):
 class ScaledConv1d(nn.Conv1d):
-    def __init__(self, *args, scale_speed = 5.0, **kwargs):
+    def __init__(self, *args, scale_speed = 5.0,
+                 initial_scale=1.0, **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
@@ -459,12 +462,13 @@ class ScaledConv1d(nn.Conv1d):
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
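In effect, each scaled module now stores its scale in log space, divided by scale_speed, so the learned parameter starts at log(initial_scale)/scale_speed rather than zero (zero corresponds to initial_scale=1.0). A minimal sketch of the mapping; the forward-time recovery of the effective multiplier shown below is an assumption based on the scale_speed parameterization visible in this diff:

    import torch

    initial_scale = 0.25
    scale_speed = 5.0

    # What __init__ now stores as the weight_scale parameter:
    weight_scale = torch.tensor(initial_scale).log() / scale_speed

    # Assumed forward-time use: exp(scale_speed * weight_scale)
    # recovers initial_scale exactly at initialization.
    effective_scale = (weight_scale * scale_speed).exp()
    print(effective_scale)  # tensor(0.2500)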


@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
                 max_factor=0.01),
             SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
         self.feed_forward_macaron = nn.Sequential(
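Here the new argument is put to use: the final projection of the feed-forward block starts at a quarter of its default output magnitude, shrinking each layer's initial contribution to the residual stream. A rough illustration, with a plain nn.Linear standing in for ScaledLinear and the 0.25 factor applied by hand, since at initialization that is all initial_scale amounts to:

    import torch
    import torch.nn as nn

    lin = nn.Linear(2048, 512)  # stand-in for ScaledLinear(dim_feedforward, d_model)
    x = torch.randn(10, 2048)

    y_default = lin(x)          # what initial_scale=1.0 would produce
    y_scaled = 0.25 * lin(x)    # what initial_scale=0.25 produces at init

    print((y_scaled.abs().mean() / y_default.abs().mean()).item())  # 0.25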


@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2",
+        default="transducer_stateless/randcombine1_expscale3_rework2b",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved