diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
index 6f19807dc..17450def8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
@@ -140,7 +140,8 @@ class Eve(Optimizer):
 # Suppose we are going to shrinkage with a small value epsilon (not the
 # same as the eps above!), i.e. param *= (1-epsilon). Then
-# if E[param_elem^2] == target_rms^2,
+# if E[param_elem^2] == target_rms^2 (because we desire equilibrium when
+# the RMS of the parameters equals target_rms), it follows that
 # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2),
 # which we can put as:
 # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 33b4ad908..4b91bb04c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -161,7 +161,7 @@ class ScaledLinear(nn.Linear):
         # we plan to use Eve as the optimizer, which will eventually make the stddev approach
         # 0.1 as that's the target_rms we set, but we initialize with a larger stddev
         # to have the same effect as a warm-up period.
-        std = 0.5 / initial_speed
+        std = 0.25 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -199,7 +199,7 @@ class ScaledConv1d(nn.Conv1d):
         self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
 
     def _reset_parameters(self, initial_speed: float):
-        std = 0.5 / initial_speed
+        std = 0.25 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -244,7 +244,7 @@ class ScaledConv2d(nn.Conv2d):
         self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class
 
     def _reset_parameters(self, initial_speed: float):
-        std = 0.5 / initial_speed
+        std = 0.25 / initial_speed
         a = (3 ** 0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
@@ -480,7 +480,7 @@ class ScaledEmbedding(nn.Module):
     def reset_parameters(self, initial_speed: float = 1.0) -> None:
-        std = 0.5 / initial_speed
+        std = 0.25 / initial_speed
         nn.init.normal_(self.weight, std=std)
         nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
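
For reference, the arithmetic behind both changes can be checked with a short PyTorch sketch (not part of the patch). It assumes target_rms = 0.1, the value the scaling.py comment refers to, and initial_speed = 1.0; the variable names are illustrative only.

import torch

# 1) Scaled* init: uniform on [-a, a] with a = sqrt(3) * std has stddev ~= std,
#    so with this patch the weights start out with stddev 0.25 / initial_speed.
initial_speed = 1.0
std = 0.25 / initial_speed
a = (3 ** 0.5) * std
w = torch.empty(1000, 1000)
torch.nn.init.uniform_(w, -a, a)
print(float(w.std()))  # ~0.25

# 2) Eve shrinkage comment: if E[param_elem^2] == target_rms^2, then param *= (1 - epsilon)
#    changes the variance by target_rms^2 * ((1 - epsilon)^2 - 1) ~= -2 * epsilon * target_rms^2.
target_rms, epsilon = 0.1, 1.0e-3
param = torch.randn(1_000_000) * target_rms  # E[param^2] ~= target_rms^2
delta_var = ((param * (1 - epsilon)) ** 2).mean() - (param ** 2).mean()
print(float(delta_var))                # ~ -2.0e-05
print(-2 * epsilon * target_rms ** 2)  # -2.0e-05

The two prints in (2) agree, which matches the delta_var_from_shrinkage approximation in the optim.py comment; (1) shows that the patch simply halves the initial weight RMS, which Eve then pulls toward target_rms during training as described in the scaling.py comment.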