Reset optimizer state when we change loss function definition.

Daniel Povey 2022-10-22 14:30:18 +08:00
parent 84580ec022
commit efde3757c7
2 changed files with 27 additions and 19 deletions

View File

@@ -228,6 +228,13 @@ class ScaledAdam(BatchedOptimizer):
         return loss
 
+    @torch.no_grad()
+    def reset(self):
+        for d in self.state.values():
+            # d should be a dict. clear all elements from it.
+            d.clear()
+
     def _init_state(self,
                     group: dict,
                     p: Tensor,
@@ -899,8 +906,8 @@ def _test_scaled_adam(hidden_dim: int):
     avg_loss = 0.0
     for epoch in range(180):
         scheduler.step_epoch()
-        #if epoch == 100 and iter in [2,3]:
-        #    optim.reset_speedup() # check it doesn't crash.
+        if epoch == 100 and iter == 1:
+            optim.reset()
         #if epoch == 130:
         #    opts = diagnostics.TensorDiagnosticOptions(
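As a rough illustration of what the new reset() does, here is a minimal, self-contained sketch using a stock torch.optim.Adam as a stand-in for ScaledAdam (the class name ResettableAdam is invented for this example). In any torch.optim.Optimizer, self.state maps each parameter to a dict of running statistics, so clearing those dicts forces them to be rebuilt from scratch on the next step().

import torch

class ResettableAdam(torch.optim.Adam):
    @torch.no_grad()
    def reset(self):
        # Each value in self.state is the per-parameter dict of running
        # statistics (step count, exp_avg, exp_avg_sq, ...); clearing it
        # makes the optimizer re-initialize them on the next step().
        for d in self.state.values():
            d.clear()

p = torch.nn.Parameter(torch.randn(4))
opt = ResettableAdam([p], lr=0.1)
p.grad = torch.randn(4)
opt.step()
print(len(opt.state[p]) > 0)   # True: running statistics now exist
opt.reset()
print(len(opt.state[p]) > 0)   # False: they will be rebuilt on the next step()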

View File

@@ -399,7 +399,8 @@ def get_params() -> AttributeDict:
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - warm_step: The warm_step for Noam optimizer.
+        - warm_step: The warmup period that dictates when we introduce the
+          pruned version of the loss.
     """
     params = AttributeDict(
         {
@@ -415,7 +416,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -603,7 +604,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -636,6 +636,9 @@ def compute_loss(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
@@ -650,13 +653,13 @@ def compute_loss(
             lm_scale=params.lm_scale,
         )
         # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
+        # for the same amount of time (warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0 if batch_idx_train < warm_step
+            else 0.1 if batch_idx_train < 2 * warm_step
+            else 1.0
         )
         loss = (
             params.simple_loss_scale * simple_loss
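The rewritten pruned_loss_scale expression keys the schedule off batch_idx_train instead of the old warmup ratio. A small worked sketch of that schedule, assuming warm_step = 3000 as set in get_params() above (the helper function is illustrative, not part of the recipe):

def pruned_loss_scale(batch_idx_train: int, warm_step: int = 3000) -> float:
    # mirrors the new expression in compute_loss()
    return (
        0.0 if batch_idx_train < warm_step
        else 0.1 if batch_idx_train < 2 * warm_step
        else 1.0
    )

assert pruned_loss_scale(1000) == 0.0   # pruned loss disabled during warmup
assert pruned_loss_scale(4500) == 0.1   # kept small for another warm_step batches
assert pruned_loss_scale(9000) == 1.0   # full weight after 2 * warm_step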
@@ -781,7 +784,6 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
             )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -791,15 +793,14 @@ def train_one_epoch(
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
 
-        if params.batch_idx_train == params.model_warm_step:
-            # we're about to start using the pruned loss, which brings new
-            # modules into play, so reset the frequencies of update, to
-            # avoid possible instability.
-            try:
-                optimizer.reset_speedup()
-                logging.info("Reset speedup on optimizer")
-            except:
-                pass
+        if params.batch_idx_train in [ params.model_warm_step,
+                                       2 * params.model_warm_step ]:
+            logging.info("Resetting optimizer state due to change in loss definition.")
+            # we're about to start using the pruned loss, or rescale it,
+            # so reset the optimizer state, to avoid
+            # possible instability due to the squared stats becoming
+            # inaccurate (too small)
+            optimizer.reset()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
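To illustrate what resetting at these two points buys, here is a small sketch (again using stock torch.optim.Adam as a stand-in for ScaledAdam; this is not code from the commit): after clearing the per-parameter state dicts, the optimizer's next update is identical to the first update of a freshly constructed optimizer, so the statistics accumulated under the old loss definition are genuinely discarded.

import torch

w = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.Adam([w], lr=0.01)

for _ in range(50):                # train for a while on the "old" loss
    w.grad = torch.randn(3)
    opt.step()

for d in opt.state.values():       # the reset() pattern from this commit
    d.clear()

w_fresh = torch.nn.Parameter(w.detach().clone())
opt_fresh = torch.optim.Adam([w_fresh], lr=0.01)

g = torch.randn(3)
w.grad, w_fresh.grad = g.clone(), g.clone()
opt.step()
opt_fresh.step()
assert torch.allclose(w, w_fresh)  # reset optimizer behaves like a brand-new one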