Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Make warmup period decrease scale on simple loss, leaving pruned loss scale constant.
Commit: 1ec9fe5c98
Parent: efde3757c7
@@ -228,13 +228,6 @@ class ScaledAdam(BatchedOptimizer):
 
         return loss
 
-    @torch.no_grad()
-    def reset(self):
-        for d in self.state.values():
-            # d should be a dict. clear all elements from it.
-            d.clear()
-
-
     def _init_state(self,
                     group: dict,
                     p: Tensor,
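
For context: the deleted reset() just emptied every per-parameter state dict so that ScaledAdam's moment statistics would be re-estimated from scratch on the next step. Below is a minimal, self-contained sketch of the same idea against a generic PyTorch optimizer (the helper name reset_optimizer_state is ours, not icefall's):

import torch

def reset_optimizer_state(optimizer: torch.optim.Optimizer) -> None:
    # optimizer.state maps each parameter to a dict of per-parameter state
    # (step counts, moment estimates, ...); clearing those dicts makes the
    # optimizer re-initialize them lazily on its next step().
    for d in optimizer.state.values():
        d.clear()

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([p], lr=1e-3)
(p ** 2).sum().backward()
opt.step()                    # populates exp_avg / exp_avg_sq for p
reset_optimizer_state(opt)    # the per-parameter dicts are empty again
assert all(len(d) == 0 for d in opt.state.values())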
@@ -906,8 +899,8 @@ def _test_scaled_adam(hidden_dim: int):
     avg_loss = 0.0
     for epoch in range(180):
         scheduler.step_epoch()
-        if epoch == 100 and iter == 1:
-            optim.reset()
+        #if epoch == 100 and iter in [2,3]:
+        #    optim.reset_speedup() # check it doesn't crash.
 
         #if epoch == 130:
         #    opts = diagnostics.TensorDiagnosticOptions(
@@ -415,8 +415,7 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            # parameters for Noam
-            "warm_step": 2000,
+            "warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
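
The new comment makes explicit that warm_step is the model-warmup horizon, not a learning-rate knob. As an illustration only (the exact formula is not part of this hunk; the removed warmup arguments further down suggest a 0-to-1 fraction), the value is typically turned into a warmup fraction that the recipe hands to the model:

# Illustrative sketch, not icefall code: derive a model-warmup fraction from
# the training step and the "warm_step" parameter above. Values >= 1.0 mean
# warmup is over; the learning-rate schedule is handled elsewhere.
def model_warmup(batch_idx_train: int, warm_step: int = 3000) -> float:
    return batch_idx_train / warm_step

print(model_warmup(0))     # 0.0  start of training
print(model_warmup(1500))  # 0.5  half way through warmup
print(model_warmup(3000))  # 1.0  warmup finished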
@@ -652,18 +651,18 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0 if batch_idx_train < warm_step
-            else 0.1 if batch_idx_train < 2 * warm_step
-            else 1.0
-        )
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
 
         loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            simple_loss_scale * simple_loss
+            + pruned_loss
         )
 
     assert loss.requires_grad == is_training
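
The effect of this hunk is easier to see outside diff form. The standalone sketch below (with illustrative values for simple_loss_scale and warm_step, since their defaults are not shown in this diff) contrasts the old pruned-loss schedule with the new simple-loss schedule: the old one jumped 0.0 -> 0.1 -> 1.0 at warm_step and 2 * warm_step, while the new one keeps the pruned loss at a constant weight of 1.0 and decays the simple-loss weight linearly from 1.0 down to simple_loss_scale over the first warm_step batches. Those jumps are what the optimizer resets removed in the train_one_epoch hunk below were guarding against.

def old_pruned_loss_scale(t: int, warm_step: int) -> float:
    # Old behaviour: pruned loss switched on in steps; simple loss weight fixed.
    return 0.0 if t < warm_step else 0.1 if t < 2 * warm_step else 1.0

def new_simple_loss_scale(t: int, warm_step: int, s: float) -> float:
    # New behaviour: pruned loss always weighted 1.0; simple loss decays
    # linearly from 1.0 to s over the first warm_step batches.
    return s if t >= warm_step else 1.0 - (t / warm_step) * (1.0 - s)

warm_step, s = 3000, 0.5   # illustrative values
for t in (0, 1500, 3000, 6000):
    print(t, old_pruned_loss_scale(t, warm_step), new_simple_loss_scale(t, warm_step, s))
# t=0:    old 0.0, new 1.0
# t=1500: old 0.0, new 0.75
# t=3000: old 0.1, new 0.5 (and it stays at 0.5 from here on)
# t=6000: old 1.0, new 0.5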
@@ -793,14 +792,6 @@ def train_one_epoch(
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
 
-        if params.batch_idx_train in [ params.model_warm_step,
-                                       2 * params.model_warm_step ]:
-            logging.info("Resetting optimizer state due to change in loss definition.")
-            # we're about to start using the pruned loss, or rescale it,
-            # so reset the optimizer state, to avoid
-            # possible instability due to the squared stats becoming
-            # inaccurate (too small)
-            optimizer.reset()
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
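
Reconstructed as a standalone helper, the removed logic looked roughly like this (a sketch, not the exact icefall code): under the old step-wise schedule the effective loss changed abruptly at warm_step and 2 * warm_step, so the loop cleared the optimizer state at those two batches; with the new continuous schedule there is no such boundary, which is why both this hook and ScaledAdam.reset() could be dropped.

import logging
import torch

def maybe_reset_at_loss_boundary(
    optimizer: torch.optim.Optimizer,
    batch_idx_train: int,
    model_warm_step: int,
) -> None:
    # Clear per-parameter optimizer state when the loss definition is about to
    # change, so that stale squared-gradient statistics (now too small for the
    # rescaled loss) do not destabilize training. This sat between backward()
    # and scaler.step(optimizer) in the old training loop.
    if batch_idx_train in (model_warm_step, 2 * model_warm_step):
        logging.info("Resetting optimizer state due to change in loss definition.")
        for d in optimizer.state.values():
            d.clear()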
@@ -1043,7 +1034,6 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
-            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -1136,7 +1126,6 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
-    warmup: float
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1154,7 +1143,6 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()