mirror of https://github.com/k2-fsa/icefall.git
Make sure param_rms limit is effectively applied; fix tests in optim.py

commit 8056e0f9af
parent c6bad1ee4f
@@ -561,8 +561,11 @@ class ScaledAdam(BatchedOptimizer):
         # when the param gets too small, just don't shrink it any further.
         scale_step.masked_fill_(is_too_small, 0.0)
         # when it gets too large, stop it from getting any larger.
         scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
+        # and ensure the parameter rms after update never exceeds param_max_rms.
+        scale_step = torch.minimum(scale_step,
+                                   (param_max_rms - param_rms) / param_rms)
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
         delta.add_(p * scale_step, alpha=(1 - beta1))
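For intuition, here is a minimal, hedged sketch (plain PyTorch, not icefall code; the concrete values are invented for illustration) of why capping scale_step at (param_max_rms - param_rms) / param_rms is enough: the size update scales the parameter by roughly (1 + scale_step), so the capped step can raise the RMS at most to param_max_rms.

import torch

# Assumed model of the size update: p_new ~= p * (1 + scale_step).
param_max_rms = 2.0
p = torch.full((4,), 3.0)              # parameter tensor whose RMS (3.0) already exceeds the limit
param_rms = (p ** 2).mean().sqrt()

scale_step = torch.tensor(0.5)         # a step that would grow the parameter further
# the cap from the diff: never let the post-update rms exceed param_max_rms
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)

p_new = p * (1 + scale_step)           # first-order effect of the size update
new_rms = (p_new ** 2).mean().sqrt()
print(scale_step.item(), new_rms.item())   # -0.3333..., 2.0
assert new_rms <= param_max_rms + 1e-6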
@@ -776,7 +779,9 @@ class Eden(LRScheduler):
 
 def _test_eden():
     m = torch.nn.Linear(100, 100)
-    optim = ScaledAdam(m.parameters(), lr=0.03)
+    parameters_names = [ [ x[0] for x in m.named_parameters() ] ]
+    optim = ScaledAdam(m.parameters(), lr=0.03,
+                       parameters_names=parameters_names)
 
     scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
 
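The fixed test now constructs ScaledAdam with parameters_names, a list holding one inner name list per parameter group. A small sketch of what that nested list contains for this model (plain PyTorch, nothing icefall-specific):

import torch

m = torch.nn.Linear(100, 100)
# one inner list per parameter group; a single group here, hence the double brackets
parameters_names = [ [ x[0] for x in m.named_parameters() ] ]
print(parameters_names)  # [['weight', 'bias']]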
@@ -987,7 +992,9 @@ def _test_scaled_adam(hidden_dim: int):
         if iter == 0:
             optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1:
-            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
+            parameters_names = [ [ x[0] for x in m.named_parameters() ] ]
+            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0,
+                               parameters_names=parameters_names)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
 
         start = timeit.default_timer()
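The test brackets its training loop with timeit.default_timer; a minimal sketch of that timing idiom (standard library only, unrelated to the optimizer change itself):

import timeit

start = timeit.default_timer()
total = sum(i * i for i in range(10 ** 6))  # stand-in for the training iterations
elapsed = timeit.default_timer() - start
print(f"elapsed: {elapsed:.3f}s (total={total})")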