mirror of https://github.com/k2-fsa/icefall.git
Reduce initial scaling of modules
commit a392cb9fbc
parent d906bc2a4f
@@ -407,12 +407,13 @@ class BasicNorm(torch.nn.Module):


 class ScaledLinear(nn.Linear):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         self.scale_speed = scale_speed
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)

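Why the log-and-divide initialization works (a minimal sketch, assuming the forward pass multiplies the raw weight by (weight_scale * scale_speed).exp(), as the surrounding code in this file does): storing weight_scale = log(initial_scale) / scale_speed makes the effective multiplier equal initial_scale exactly at step 0.

    import torch

    scale_speed = 5.0
    initial_scale = 0.25
    # Stored parameter: log-domain scale, divided by scale_speed.
    weight_scale = torch.tensor(initial_scale).log() / scale_speed
    # Multiplier the forward pass would recover at initialization.
    effective = (weight_scale * scale_speed).exp()
    print(effective)  # tensor(0.2500)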
@@ -431,12 +432,14 @@ class ScaledLinear(nn.Linear):


 class ScaledConv1d(nn.Conv1d):
-    def __init__(self, *args, scale_speed = 5.0, **kwargs):
+    def __init__(self, *args, scale_speed = 5.0,
+                 initial_scale=1.0, **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)

@@ -459,12 +462,13 @@ class ScaledConv1d(nn.Conv1d):


 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)

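ScaledConv1d and ScaledConv2d get the same treatment as ScaledLinear. A side note on scale_speed (my sketch, not part of this diff): since the effective multiplier is exp(scale_speed * weight_scale), its gradient with respect to the stored parameter is scale_speed times the multiplier itself, which is presumably why scale_speed=5.0 lets the scales adapt faster at a fixed learning rate.

    import torch

    scale_speed = 5.0
    weight_scale = (torch.tensor(0.25).log() / scale_speed).requires_grad_()
    effective = (weight_scale * scale_speed).exp()
    effective.backward()
    # d(effective)/d(weight_scale) = scale_speed * effective
    print(weight_scale.grad.item(), scale_speed * effective.item())  # 1.25, 1.25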
@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
             max_factor=0.01),
             SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )

         self.feed_forward_macaron = nn.Sequential(
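The one call site changed here passes initial_scale=0.25 to the last ScaledLinear of the feed-forward block, so that branch starts out contributing a quarter of its usual magnitude. A hypothetical standalone re-implementation to illustrate the effect (DemoScaledLinear is my name, the forward formula is assumed from the surrounding library code, and bias scaling is omitted for brevity):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DemoScaledLinear(nn.Linear):
        def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
            super().__init__(*args, **kwargs)
            # Same log-domain initialization as in the diff above.
            s = torch.tensor(initial_scale).log() / scale_speed
            self.weight_scale = nn.Parameter(s.clone().detach())
            self.scale_speed = scale_speed

        def forward(self, x):
            # Effective weight = raw weight times exp(scale_speed * weight_scale);
            # at initialization that multiplier equals initial_scale.
            w = self.weight * (self.weight_scale * self.scale_speed).exp()
            return F.linear(x, w, self.bias)

    x = torch.randn(8, 16)
    base = DemoScaledLinear(16, 16, bias=False)
    small = DemoScaledLinear(16, 16, bias=False, initial_scale=0.25)
    with torch.no_grad():
        small.weight.copy_(base.weight)
    print(base(x).std().item(), small(x).std().item())  # second is ~0.25x the first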
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2",
+        default="transducer_stateless/randcombine1_expscale3_rework2b",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved