mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-15 20:22:42 +00:00
Fix errors in madam.py
This commit is contained in:
parent
66467f2da8
commit
b7d4a4f983
@ -855,7 +855,7 @@ class Foam(object):
|
|||||||
min_target_rms: float = 0.05,
|
min_target_rms: float = 0.05,
|
||||||
limit_grad_factor: float = float('inf'),
|
limit_grad_factor: float = float('inf'),
|
||||||
l2_period: int = 1) -> None:
|
l2_period: int = 1) -> None:
|
||||||
"""Construct an Foam object."""
|
"""Construct an Noam object."""
|
||||||
self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
|
self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
|
||||||
min_target_rms=min_target_rms,
|
min_target_rms=min_target_rms,
|
||||||
limit_grad_factor=limit_grad_factor,
|
limit_grad_factor=limit_grad_factor,
|
||||||
@ -933,20 +933,15 @@ class Foam(object):
|
|||||||
"""Return state_dict."""
|
"""Return state_dict."""
|
||||||
return {
|
return {
|
||||||
"_step": self._step,
|
"_step": self._step,
|
||||||
"warmup": self.warmup,
|
|
||||||
"factor": self.factor,
|
|
||||||
"model_size": self.model_size,
|
|
||||||
"_rate": self._rate,
|
|
||||||
"optimizer": self.optimizer.state_dict(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
"""Load state_dict."""
|
"""Load state_dict. This is compatible with reading a Moam state_dict"""
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
if key == "optimizer":
|
if key == "optimizer":
|
||||||
self.optimizer.load_state_dict(state_dict["optimizer"])
|
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||||
else:
|
elif key == '_step':
|
||||||
setattr(self, key, value)
|
self._step = value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1096,6 +1091,11 @@ def test_foam():
|
|||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
state_dict = optimizer.state_dict()
|
||||||
|
step = optimizer._step
|
||||||
|
optimizer._step = 0
|
||||||
|
optimizer.load_state_dict(state_dict)
|
||||||
|
assert optimizer._step == step
|
||||||
|
|
||||||
|
|
||||||
def test_to_device():
|
def test_to_device():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user