diff --git a/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py b/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py index bba17c375..019dbbe43 100644 --- a/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py +++ b/egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py @@ -855,7 +855,7 @@ class Foam(object): min_target_rms: float = 0.05, limit_grad_factor: float = float('inf'), l2_period: int = 1) -> None: - """Construct an Foam object.""" + """Construct a Foam object.""" self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9, min_target_rms=min_target_rms, limit_grad_factor=limit_grad_factor, @@ -933,20 +933,15 @@ class Foam(object): """Return state_dict.""" return { "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), } def load_state_dict(self, state_dict): - """Load state_dict.""" + """Load state_dict. This is compatible with reading a Moam state_dict.""" for key, value in state_dict.items(): if key == "optimizer": self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) + elif key == '_step': + self._step = value @@ -1096,6 +1091,11 @@ def test_foam(): model.zero_grad() print("") + state_dict = optimizer.state_dict() + step = optimizer._step + optimizer._step = 0 + optimizer.load_state_dict(state_dict) + assert optimizer._step == step def test_to_device():