Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00

Reset optimizer state when we change loss function definition.

This commit is contained in:
parent 84580ec022
commit efde3757c7
@@ -228,6 +228,13 @@ class ScaledAdam(BatchedOptimizer):
         return loss
 
+    @torch.no_grad()
+    def reset(self):
+        for d in self.state.values():
+            # d should be a dict. clear all elements from it.
+            d.clear()
+
+
     def _init_state(self,
                     group: dict,
                     p: Tensor,
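For context, the new reset() simply empties every per-parameter state dict held by the optimizer, so the step counts and squared-gradient statistics are rebuilt from scratch on the next step(). A minimal sketch of the same pattern on a stock PyTorch optimizer (the toy model and the plain Adam stand-in are illustrative, not from this commit):

import torch

# Toy model and a plain Adam optimizer stand in for ScaledAdam here.
model = torch.nn.Linear(8, 4)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# One step so the optimizer accumulates per-parameter state
# (step count, exp_avg, exp_avg_sq, ...).
model(torch.randn(2, 8)).sum().backward()
optim.step()

# Same idea as the added reset(): clear every per-parameter state dict
# so the running statistics are re-initialized on the next step.
with torch.no_grad():
    for d in optim.state.values():
        d.clear()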
@@ -899,8 +906,8 @@ def _test_scaled_adam(hidden_dim: int):
     avg_loss = 0.0
     for epoch in range(180):
         scheduler.step_epoch()
-        #if epoch == 100 and iter in [2,3]:
-        #    optim.reset_speedup() # check it doesn't crash.
+        if epoch == 100 and iter == 1:
+            optim.reset()
 
         #if epoch == 130:
         #    opts = diagnostics.TensorDiagnosticOptions(
@@ -399,7 +399,8 @@ def get_params() -> AttributeDict:
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - warm_step: The warm_step for Noam optimizer.
+        - warm_step: The warmup period that dictates when we introduce the
+          pruned version of the loss.
     """
     params = AttributeDict(
         {
@@ -415,7 +416,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 3000, # arg given to model, not for lrate
+            "warm_step": 3000, # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
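As a reminder of how this value is consumed downstream: get_params() returns an AttributeDict, a dict whose keys can also be read as attributes, so later code can write params.warm_step as well as params["warm_step"]. A tiny self-contained sketch of that behaviour (this stand-in class is ours, not icefall's own implementation):

# Minimal stand-in for an attribute-style dict such as icefall's AttributeDict:
# keys are readable and writable as attributes.
class AttributeDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value


params = AttributeDict({"warm_step": 3000})
assert params.warm_step == params["warm_step"] == 3000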
@@ -603,7 +604,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -636,6 +636,9 @@ def compute_loss(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
@@ -650,13 +653,13 @@ def compute_loss(
             lm_scale=params.lm_scale,
         )
         # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
+        # for the same amount of time (warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0 if batch_idx_train < warm_step
+            else 0.1 if batch_idx_train < 2 * warm_step
+            else 1.0
         )
         loss = (
             params.simple_loss_scale * simple_loss
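Restated outside the diff, the new pruned_loss_scale is a pure function of the global batch index rather than of the old warmup ratio. A small sketch of the schedule with the default warm_step of 3000 (the helper name below is ours, for illustration only):

def pruned_loss_scale(batch_idx_train: int, warm_step: int = 3000) -> float:
    # Same piecewise schedule as in the diff: the pruned loss is off until
    # warm_step, kept small until 2 * warm_step, then used at full weight.
    if batch_idx_train < warm_step:
        return 0.0
    elif batch_idx_train < 2 * warm_step:
        return 0.1
    else:
        return 1.0


assert pruned_loss_scale(0) == 0.0      # batches [0, 3000): pruned loss disabled
assert pruned_loss_scale(3000) == 0.1   # batches [3000, 6000): scaled down to 0.1
assert pruned_loss_scale(6000) == 1.0   # batches >= 6000: full weight

These two boundaries are exactly the batches at which the loss definition changes, which is why the optimizer state is reset at the same points further down.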
@@ -781,7 +784,6 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
             )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -791,15 +793,14 @@ def train_one_epoch(
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
 
-        if params.batch_idx_train == params.model_warm_step:
-            # we're about to start using the pruned loss, which brings new
-            # modules into play, so reset the frequencies of update, to
-            # avoid possible instability.
-            try:
-                optimizer.reset_speedup()
-                logging.info("Reset speedup on optimizer")
-            except:
-                pass
+        if params.batch_idx_train in [ params.model_warm_step,
+                                       2 * params.model_warm_step ]:
+            logging.info("Resetting optimizer state due to change in loss definition.")
+            # we're about to start using the pruned loss, or rescale it,
+            # so reset the optimizer state, to avoid
+            # possible instability due to the squared stats becoming
+            # inaccurate (too small)
+            optimizer.reset()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
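The new condition can be read on its own: the loss definition changes when the pruned loss is first introduced (at model_warm_step batches) and again when its scale jumps from 0.1 to 1.0 (at 2 * model_warm_step), and at those two points the accumulated squared-gradient statistics no longer match the new loss surface. A standalone sketch of the same check (the helper function is illustrative, not part of the recipe):

def maybe_reset_optimizer(optimizer, batch_idx_train: int, model_warm_step: int) -> None:
    # Reset at the two batches where the loss definition changes:
    # model_warm_step (pruned loss switched on at scale 0.1) and
    # 2 * model_warm_step (its scale raised to 1.0).
    if batch_idx_train in (model_warm_step, 2 * model_warm_step):
        optimizer.reset()

# e.g. called once per batch as
# maybe_reset_optimizer(optimizer, params.batch_idx_train, params.model_warm_step)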