Fix errors in madam.py

This commit is contained in:
Fangjun Kuang 2021-08-26 22:28:18 +08:00
parent 66467f2da8
commit b7d4a4f983

View File

@ -855,7 +855,7 @@ class Foam(object):
min_target_rms: float = 0.05,
limit_grad_factor: float = float('inf'),
l2_period: int = 1) -> None:
"""Construct an Foam object."""
"""Construct an Noam object."""
self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
min_target_rms=min_target_rms,
limit_grad_factor=limit_grad_factor,
@ -933,20 +933,15 @@ class Foam(object):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
"""Load state_dict. This is compatible with reading a Moam state_dict"""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
elif key == '_step':
self._step = value
@ -1096,6 +1091,11 @@ def test_foam():
model.zero_grad()
print("")
state_dict = optimizer.state_dict()
step = optimizer._step
optimizer._step = 0
optimizer.load_state_dict(state_dict)
assert optimizer._step == step
def test_to_device():