from local

This commit is contained in:
dohe0342 2022-12-10 14:01:43 +09:00
parent c7bfa2c95a
commit b566f3b925
2 changed files with 15 additions and 5 deletions

View File

@ -958,12 +958,22 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)
if params.multi_optim and batch_idx % params.accum_grads == 0:
set_batch_count(model, params.batch_idx_train)
scheduler_enc.step_batch(params.batch_idx_train)
scheduler_dec.step_batch(params.batch_idx_train)
scaler.step(optimizer_enc)
scaler.step(optimizer_dec)
scaler.update()
optimizer_enc.zero_grad()
optimizer_dec.zero_grad()
elif not params.multi_optim and batch_idx % params.accum_grads == 0:
set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, sp=sp)
raise