Make warmup period decrease scale on simple loss, leaving pruned loss scale constant.

Daniel Povey 2022-10-22 14:48:53 +08:00
parent efde3757c7
commit 1ec9fe5c98
2 changed files with 13 additions and 32 deletions
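In effect, the staged pruned-loss schedule (scale 0.0, then 0.1, then 1.0, with optimizer resets at each jump) is replaced by a linear ramp on the simple-loss scale, while the pruned loss is weighted 1.0 throughout. A minimal sketch of the new weighting follows; the helper name and default values are illustrative, the real warm_step and simple_loss_scale come from the training config.

def combine_losses(simple_loss: float,
                   pruned_loss: float,
                   batch_idx_train: int,
                   warm_step: int = 2000,
                   s: float = 0.5) -> float:
    # Ramp the simple-loss scale linearly from 1.0 at batch 0 down to s
    # at warm_step, then hold it there; the pruned loss keeps scale 1.0.
    if batch_idx_train >= warm_step:
        simple_loss_scale = s
    else:
        simple_loss_scale = 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    return simple_loss_scale * simple_loss + pruned_loss

# e.g. with warm_step=2000 and s=0.5:
#   batch    0 -> simple_loss_scale = 1.0
#   batch 1000 -> simple_loss_scale = 0.75
#   batch 2000 -> simple_loss_scale = 0.5 (held constant from here on)

The gradual ramp avoids the abrupt change in the loss definition that previously required resetting the optimizer state at the warm_step boundaries.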


@@ -228,13 +228,6 @@ class ScaledAdam(BatchedOptimizer):
return loss
@torch.no_grad()
def reset(self):
for d in self.state.values():
# d should be a dict. clear all elements from it.
d.clear()
def _init_state(self,
group: dict,
p: Tensor,
@@ -906,8 +899,8 @@ def _test_scaled_adam(hidden_dim: int):
avg_loss = 0.0
for epoch in range(180):
scheduler.step_epoch()
if epoch == 100 and iter == 1:
optim.reset()
#if epoch == 100 and iter in [2,3]:
# optim.reset_speedup() # check it doesn't crash.
#if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(


@@ -415,8 +415,7 @@ def get_params() -> AttributeDict:
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
# parameters for Noam
"warm_step": 3000, # arg given to model, not for lrate
"warm_step": 2000,
"env_info": get_env_info(),
}
)
@@ -652,18 +651,18 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if batch_idx_train < warm_step
else 0.1 if batch_idx_train < 2 * warm_step
else 1.0
s = params.simple_loss_scale
# ramp down the scale on the simple loss from 1.0 at the start
# to params.simple_loss_scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
simple_loss_scale * simple_loss
+ pruned_loss
)
assert loss.requires_grad == is_training
@@ -793,14 +792,6 @@ def train_one_epoch(
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
if params.batch_idx_train in [ params.model_warm_step,
2 * params.model_warm_step ]:
logging.info("Resetting optimizer state due to change in loss definition.")
# we're about to start using the pruned loss, or rescale it,
# so reset the optimizer state, to avoid
# possible instability due to the squared stats becoming
# inaccurate (too small)
optimizer.reset()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
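For context, a rough sketch of the behaviour deleted in this hunk, using a hypothetical helper name: under the old staged schedule, the per-parameter optimizer state was wiped at each boundary where the loss definition changed, which is what the removed ScaledAdam.reset() did.

import logging

def maybe_reset_optimizer(optimizer, batch_idx_train: int, model_warm_step: int) -> None:
    # At the two points where the old schedule changed the pruned-loss scale,
    # clear each parameter's cached statistics (e.g. the squared-gradient
    # averages) so they are re-estimated under the new loss definition.
    if batch_idx_train in (model_warm_step, 2 * model_warm_step):
        logging.info("Resetting optimizer state due to change in loss definition.")
        for state in optimizer.state.values():
            state.clear()  # same effect as the removed ScaledAdam.reset()

With the linear simple-loss ramp and a constant pruned-loss scale, the loss definition no longer changes abruptly, so neither this check nor ScaledAdam.reset() is needed.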
@@ -1043,7 +1034,6 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@@ -1136,7 +1126,6 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
warmup: float
):
from lhotse.dataset import find_pessimistic_batches
@@ -1154,7 +1143,6 @@
sp=sp,
batch=batch,
is_training=True,
warmup=warmup,
)
loss.backward()
optimizer.step()