Mirror of https://github.com/k2-fsa/icefall.git
Reset optimizer state when we change loss function definition.

commit efde3757c7, parent 84580ec022
@@ -228,6 +228,13 @@ class ScaledAdam(BatchedOptimizer):
         return loss
 
+    @torch.no_grad()
+    def reset(self):
+        for d in self.state.values():
+            # d should be a dict. clear all elements from it.
+            d.clear()
+
+
     def _init_state(self,
                     group: dict,
                     p: Tensor,
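The new reset() relies on the torch.optim.Optimizer convention that self.state maps each parameter to a dict of accumulated statistics; clearing those dicts makes the next step() re-initialize the statistics lazily (via _init_state, shown just below). For intuition, a minimal stand-alone sketch of the same idea, using torch.optim.Adam as a stand-in since ScaledAdam's constructor arguments are not part of this diff:

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

model(torch.randn(2, 4)).sum().backward()
opt.step()              # populates exp_avg / exp_avg_sq for each parameter
print(len(opt.state))   # 2: one state dict per parameter (weight and bias)

# Equivalent of ScaledAdam.reset(): empty each per-parameter dict so the
# running moment estimates are rebuilt from scratch on the next step().
for d in opt.state.values():
    d.clear()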
@@ -899,8 +906,8 @@ def _test_scaled_adam(hidden_dim: int):
     avg_loss = 0.0
     for epoch in range(180):
         scheduler.step_epoch()
-        #if epoch == 100 and iter in [2,3]:
-        #    optim.reset_speedup() # check it doesn't crash.
+        if epoch == 100 and iter == 1:
+            optim.reset()
 
         #if epoch == 130:
         #    opts = diagnostics.TensorDiagnosticOptions(
@@ -399,7 +399,8 @@ def get_params() -> AttributeDict:
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - warm_step: The warm_step for Noam optimizer.
+        - warm_step: The warmup period that dictates when we introduce the
+          pruned version of the loss.
     """
     params = AttributeDict(
         {
@@ -415,7 +416,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -603,7 +604,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -636,6 +636,9 @@ def compute_loss(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
@@ -650,13 +653,13 @@ def compute_loss(
             lm_scale=params.lm_scale,
         )
         # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
+        # for the same amount of time (warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0 if batch_idx_train < warm_step
+            else 0.1 if batch_idx_train < 2 * warm_step
+            else 1.0
         )
         loss = (
             params.simple_loss_scale * simple_loss
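With the warmup argument gone, the scale on the pruned loss is now driven by the global batch counter: 0.0 before warm_step, 0.1 between warm_step and 2 * warm_step, and 1.0 afterwards. A minimal sketch of that schedule as a standalone function (the helper name pruned_loss_scale_for is ours, not part of the recipe), using the warm_step value of 3000 set in get_params():

def pruned_loss_scale_for(batch_idx_train: int, warm_step: int) -> float:
    # 0.0 while the simple loss is still learning the alignment,
    # 0.1 for an equal number of batches so the pruned loss is phased in gently,
    # 1.0 once training is past 2 * warm_step.
    if batch_idx_train < warm_step:
        return 0.0
    elif batch_idx_train < 2 * warm_step:
        return 0.1
    else:
        return 1.0

assert pruned_loss_scale_for(2999, warm_step=3000) == 0.0
assert pruned_loss_scale_for(3000, warm_step=3000) == 0.1
assert pruned_loss_scale_for(6000, warm_step=3000) == 1.0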
@@ -781,7 +784,6 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
             )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -791,15 +793,14 @@ def train_one_epoch(
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
 
-        if params.batch_idx_train == params.model_warm_step:
-            # we're about to start using the pruned loss, which brings new
-            # modules into play, so reset the frequencies of update, to
-            # avoid possible instability.
-            try:
-                optimizer.reset_speedup()
-                logging.info("Reset speedup on optimizer")
-            except:
-                pass
+        if params.batch_idx_train in [ params.model_warm_step,
+                                       2 * params.model_warm_step ]:
+            logging.info("Resetting optimizer state due to change in loss definition.")
+            # we're about to start using the pruned loss, or rescale it,
+            # so reset the optimizer state, to avoid
+            # possible instability due to the squared stats becoming
+            # inaccurate (too small)
+            optimizer.reset()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
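These are exactly the batch indices at which the loss definition changes under the new schedule, assuming model_warm_step and warm_step both refer to the same 3000-batch period as the diff suggests: the pruned loss switches on at the first point and is upweighted from 0.1 to 1.0 at the second. A small illustration, reusing the hypothetical pruned_loss_scale_for helper from the sketch above:

warm_step = 3000
for batch_idx in (2999, 3000, 3001, 5999, 6000, 6001):
    scale = pruned_loss_scale_for(batch_idx, warm_step)
    reset_here = batch_idx in (warm_step, 2 * warm_step)
    print(f"batch {batch_idx}: pruned_loss_scale={scale}, reset_optimizer={reset_here}")

# batch 2999: pruned_loss_scale=0.0, reset_optimizer=False
# batch 3000: pruned_loss_scale=0.1, reset_optimizer=True
# batch 3001: pruned_loss_scale=0.1, reset_optimizer=False
# batch 5999: pruned_loss_scale=0.1, reset_optimizer=False
# batch 6000: pruned_loss_scale=1.0, reset_optimizer=True
# batch 6001: pruned_loss_scale=1.0, reset_optimizer=False

Resetting at these two points discards the squared-gradient statistics that ScaledAdam accumulated under the old loss, which, as the comment in the hunk above notes, would otherwise be too small once the pruned term starts contributing gradient.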