from local

This commit is contained in:
dohe0342 2023-01-09 21:42:23 +09:00
parent 842844270e
commit ecc9343c94
3 changed files with 2 additions and 2 deletions

View File

@ -838,7 +838,7 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
scaler.scale(loss).backward() scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train) #scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
@ -1121,7 +1121,7 @@ def run(rank, world_size, args):
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1): for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1) #scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1) fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1) train_dl.sampler.set_epoch(epoch - 1)