Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
Reduce initial scaling of modules
parent d906bc2a4f
commit a392cb9fbc
```diff
@@ -407,12 +407,13 @@ class BasicNorm(torch.nn.Module):
 
 
 class ScaledLinear(nn.Linear):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         self.scale_speed = scale_speed
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
 
```
```diff
@@ -431,12 +432,14 @@ class ScaledLinear(nn.Linear):
 
 
 class ScaledConv1d(nn.Conv1d):
-    def __init__(self, *args, scale_speed = 5.0, **kwargs):
+    def __init__(self, *args, scale_speed = 5.0,
+                 initial_scale=1.0, **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
 
```
```diff
@@ -459,12 +462,13 @@ class ScaledConv1d(nn.Conv1d):
 
 
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
 
```
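All three modules share one parameterization: the learned scale is stored in log space, divided by `scale_speed`, so that `exp(scale_speed * weight_scale)` recovers the actual multiplier. Before this commit the parameter was initialized to zero (a multiplier of exactly 1.0); initializing it to `log(initial_scale) / scale_speed` instead makes the module start at the requested multiplier. Below is a minimal sketch of that idea; `ScaledLinearSketch` is a hypothetical name, and the forward pass is an assumption consistent with this parameterization, not a copy of the icefall source:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledLinearSketch(nn.Linear):
    """Sketch of the log-space scale parameterization in this commit."""

    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale_speed = scale_speed
        # Store log(initial_scale) / scale_speed so that
        # exp(scale_speed * weight_scale) == initial_scale at init.
        initial_scale = torch.tensor(initial_scale).log() / scale_speed
        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
        if self.bias is not None:
            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
        else:
            self.register_parameter('bias_scale', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed forward pass: rescale weight/bias by the recovered multiplier.
        weight = self.weight * (self.weight_scale * self.scale_speed).exp()
        bias = (None if self.bias is None
                else self.bias * (self.bias_scale * self.scale_speed).exp())
        return F.linear(x, weight, bias)


layer = ScaledLinearSketch(512, 256, initial_scale=0.25)
# The effective weight multiplier at init is the requested initial_scale:
print((layer.weight_scale * layer.scale_speed).exp().item())  # ≈ 0.25
```

Dividing by `scale_speed` keeps the stored parameter small while leaving the recovered multiplier unchanged, since the forward pass multiplies the same factor back in.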
```diff
@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
                        max_factor=0.01),
             SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
         self.feed_forward_macaron = nn.Sequential(
```
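This is where the commit title takes effect: the last linear layer of the Conformer feed-forward branch now starts with a 0.25x output multiplier instead of 1.0x. A quick check of the stored parameter value under the parameterization above:

```python
import math

scale_speed = 5.0
initial_scale = 0.25
# Stored parameter value at init, as in the diff above:
weight_scale = math.log(initial_scale) / scale_speed  # ≈ -0.2773
# Multiplier recovered as exp(scale_speed * weight_scale):
print(math.exp(scale_speed * weight_scale))  # 0.25
```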
```diff
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2",
+        default="transducer_stateless/randcombine1_expscale3_rework2b",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
```
Loading…
x
Reference in New Issue
Block a user