diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 831537d79..dab0e1e1d 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -407,12 +407,13 @@ class BasicNorm(torch.nn.Module):
 
 
 class ScaledLinear(nn.Linear):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledLinear, self).__init__(*args, **kwargs)
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         self.scale_speed = scale_speed
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
 
@@ -431,12 +432,14 @@
 
 
 class ScaledConv1d(nn.Conv1d):
-    def __init__(self, *args, scale_speed = 5.0, **kwargs):
+    def __init__(self, *args, scale_speed = 5.0,
+                 initial_scale=1.0, **kwargs):
         super(ScaledConv1d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
 
@@ -459,12 +462,13 @@
 
 
 class ScaledConv2d(nn.Conv2d):
-    def __init__(self, *args, scale_speed=5.0, **kwargs):
+    def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
         super(ScaledConv2d, self).__init__(*args, **kwargs)
         self.scale_speed = scale_speed
-        self.weight_scale = nn.Parameter(torch.zeros(()))
+        initial_scale = (torch.tensor(initial_scale).log() / scale_speed)
+        self.weight_scale = nn.Parameter(initial_scale.clone().detach())
         if self.bias is not None:
-            self.bias_scale = nn.Parameter(torch.zeros(()))
+            self.bias_scale = nn.Parameter(initial_scale.clone().detach())
         else:
             self.register_parameter('bias_scale', None)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 93f7dd170..aa35f5e7e 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
                                max_factor=0.01),
             SwishExpScale(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
         self.feed_forward_macaron = nn.Sequential(
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 1a57d654f..b871efd13 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2",
+        default="transducer_stateless/randcombine1_expscale3_rework2b",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
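
Note (editor's sketch, not part of the patch): the diff only touches initialization, but the division by scale_speed is easiest to see against the forward pass. Assuming these modules compute the effective weight as weight * exp(weight_scale * scale_speed) — which the log() / scale_speed arithmetic implies, though this diff does not show the forward pass — initializing weight_scale to log(initial_scale) / scale_speed makes the effective scale start at exactly initial_scale. The old torch.zeros(()) initialization is then the special case initial_scale=1.0. A minimal self-contained sketch under that assumption (ScaledLinearSketch and get_weight are illustrative names, not the repo's API):

    import torch
    import torch.nn as nn

    class ScaledLinearSketch(nn.Linear):
        """nn.Linear with a learned scalar scale stored in log space."""

        def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs):
            super().__init__(*args, **kwargs)
            self.scale_speed = scale_speed
            # Store log(initial_scale) / scale_speed so that
            # exp(weight_scale * scale_speed) == initial_scale at init.
            log_scale = torch.tensor(initial_scale).log() / scale_speed
            self.weight_scale = nn.Parameter(log_scale.clone().detach())

        def get_weight(self) -> torch.Tensor:
            # Effective weight: raw weight times the learned exponential scale
            # (assumed forward-pass convention, not shown in the diff above).
            return self.weight * (self.weight_scale * self.scale_speed).exp()

    # With initial_scale=0.25 (the value used for the ConformerEncoderLayer's
    # final feed-forward projection), the effective scale starts at 0.25:
    layer = ScaledLinearSketch(4, 4, initial_scale=0.25)
    assert torch.isclose(
        (layer.weight_scale * layer.scale_speed).exp(), torch.tensor(0.25)
    )

Under this reading, the initial_scale=0.25 in conformer.py shrinks the feed-forward branch's output at the start of training, while scale_speed only changes how fast gradient steps move the effective scale, not its initial value.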