fix error in accum_grad (#1693)

Author: zzasdf (2024-07-17 17:47:43 +08:00), committed by GitHub
parent 2e13298717
commit 11151415f3
6 changed files with 6 additions and 6 deletions
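
The same one-line fix is applied to train_one_epoch in each of the six files: the check that guards optimizer.zero_grad() now tests sub_batch_idx, the counter that actually drives gradient accumulation, instead of batch_idx.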

@@ -948,7 +948,7 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )

-    if batch_idx % params.accum_grad != params.accum_grad - 1:
+    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
         optimizer.zero_grad()
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value

@@ -948,7 +948,7 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )

-    if batch_idx % params.accum_grad != params.accum_grad - 1:
+    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
         optimizer.zero_grad()
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value

@@ -774,7 +774,7 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )

-    if batch_idx % params.accum_grad != params.accum_grad - 1:
+    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
         optimizer.zero_grad()
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value

@@ -774,7 +774,7 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )

-    if batch_idx % params.accum_grad != params.accum_grad - 1:
+    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
         optimizer.zero_grad()
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value

@@ -1245,7 +1245,7 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )

-    if batch_idx % params.accum_grad != params.accum_grad - 1:
+    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
         optimizer.zero_grad()
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value

@@ -1072,7 +1072,7 @@ def train_one_epoch(
                 tb_writer, "train/valid_", params.batch_idx_train
             )

-    if batch_idx % params.accum_grad != params.accum_grad - 1:
+    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
         optimizer.zero_grad()
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
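
For context, below is a minimal sketch of the gradient-accumulation pattern this check belongs to. The loop structure and the split()/compute_loss() helpers are hypothetical, and the placement of the check after the loop is inferred from the surrounding context (loss_value is computed at epoch end); only sub_batch_idx, batch_idx, params.accum_grad, and optimizer.zero_grad() come from the diff.

# Minimal sketch of the accumulation pattern behind this fix; helper names
# split() and compute_loss() are hypothetical stand-ins.
def train_one_epoch(model, optimizer, train_dl, params):
    for batch_idx, batch in enumerate(train_dl):
        for sub_batch_idx, sub_batch in enumerate(split(batch)):
            loss = compute_loss(model, sub_batch)
            # Scale so accum_grad sub-batches sum to one full-batch gradient.
            (loss / params.accum_grad).backward()
            if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
                # A full window of accum_grad sub-batches has accumulated.
                optimizer.step()
                optimizer.zero_grad()
    # After the loop, sub_batch_idx keeps its final value (Python loop
    # variables outlive the loop).  If the last accumulation window was
    # incomplete, the leftover gradients were never applied; drop them so
    # they cannot leak into the next epoch.  Testing batch_idx here (the
    # old code) checked the wrong counter.
    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
        optimizer.zero_grad()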