Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Make warmup period decrease scale on simple loss, leaving pruned loss scale constant.

Commit 1ec9fe5c98, parent efde3757c7.
@@ -228,13 +228,6 @@ class ScaledAdam(BatchedOptimizer):
         return loss
 
-    @torch.no_grad()
-    def reset(self):
-        for d in self.state.values():
-            # d should be a dict. clear all elements from it.
-            d.clear()
-
-
     def _init_state(self,
                     group: dict,
                     p: Tensor,
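The removed reset() simply empties every per-parameter state dict, which discards the optimizer's running gradient statistics. Below is a minimal sketch of that idea against a stock torch.optim.Adam instance, used purely as a stand-in (ScaledAdam itself is not reproduced here):

import torch

param = torch.nn.Parameter(torch.randn(4))
optim = torch.optim.Adam([param], lr=1e-3)

param.sum().backward()
optim.step()
print(len(optim.state[param]))  # non-empty: step count, exp_avg, exp_avg_sq

# What the removed ScaledAdam.reset() body did, applied to this optimizer:
for d in optim.state.values():
    d.clear()  # drop the accumulated statistics

print(len(optim.state[param]))  # 0: the next step() re-initializes them from scratch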
@@ -906,8 +899,8 @@ def _test_scaled_adam(hidden_dim: int):
     avg_loss = 0.0
     for epoch in range(180):
         scheduler.step_epoch()
-        if epoch == 100 and iter == 1:
-            optim.reset()
+        #if epoch == 100 and iter in [2,3]:
+        #    optim.reset_speedup()  # check it doesn't crash.
 
         #if epoch == 130:
         #    opts = diagnostics.TensorDiagnosticOptions(
@@ -415,8 +415,7 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "warm_step": 3000,  # arg given to model, not for lrate
+            "warm_step": 2000,
             "env_info": get_env_info(),
         }
     )
@@ -652,18 +651,18 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0 if batch_idx_train < warm_step
-            else 0.1 if batch_idx_train < 2 * warm_step
-            else 1.0
-        )
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
 
         loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            simple_loss_scale * simple_loss
+            + pruned_loss
         )
 
         assert loss.requires_grad == is_training
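For reference, the replacement schedule is easy to tabulate. The sketch below restates the new formula as a standalone function; warm_step = 2000 matches the new default from get_params above, while the simple_loss_scale value of 0.5 is only an assumed example, not something fixed by this diff:

def loss_scales(batch_idx_train: int,
                warm_step: int = 2000,
                simple_loss_scale: float = 0.5):
    """Return (simple_loss weight, pruned_loss weight) under the new scheme."""
    s = simple_loss_scale
    # The simple-loss weight ramps linearly from 1.0 down to s over the first
    # warm_step batches; the pruned-loss weight is a constant 1.0 throughout.
    simple = (
        s if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    return simple, 1.0

for b in (0, 500, 1000, 2000, 10000):
    print(b, loss_scales(b))
# 0     (1.0, 1.0)
# 500   (0.875, 1.0)
# 1000  (0.75, 1.0)
# 2000  (0.5, 1.0)
# 10000 (0.5, 1.0)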
@@ -793,14 +792,6 @@ def train_one_epoch(
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
 
-        if params.batch_idx_train in [ params.model_warm_step,
-                                       2 * params.model_warm_step ]:
-            logging.info("Resetting optimizer state due to change in loss definition.")
-            # we're about to start using the pruned loss, or rescale it,
-            # so reset the optimizer state, to avoid
-            # possible instability due to the squared stats becoming
-            # inaccurate (too small)
-            optimizer.reset()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
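For contrast, the schedule that motivated this reset was the step-wise pruned-loss weight removed from compute_loss above. It is restated below as a paraphrase of the removed code (not new behaviour), to make explicit why the new continuous ramp no longer needs an optimizer reset:

def old_pruned_loss_scale(batch_idx_train: int, warm_step: int) -> float:
    # Removed scheme: the pruned-loss weight jumped 0.0 -> 0.1 at warm_step and
    # 0.1 -> 1.0 at 2 * warm_step, so the optimizer's squared-gradient statistics
    # could suddenly be too small for the enlarged loss; reset() was called at
    # exactly those two batch indices to let the statistics be re-estimated.
    if batch_idx_train < warm_step:
        return 0.0
    if batch_idx_train < 2 * warm_step:
        return 0.1
    return 1.0

With the new scheme the total loss changes smoothly with batch_idx_train, so there is no discrete point at which the accumulated statistics go stale.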
@@ -1043,7 +1034,6 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
-            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -1136,7 +1126,6 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
-    warmup: float
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1154,7 +1143,6 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()