mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
fix error in accum_grad (#1693)
This commit is contained in:
parent
2e13298717
commit
11151415f3
@ -948,7 +948,7 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
optimizer.zero_grad()
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
|
@ -948,7 +948,7 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
optimizer.zero_grad()
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
|
@ -774,7 +774,7 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
optimizer.zero_grad()
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
|
@ -774,7 +774,7 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
optimizer.zero_grad()
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
|
@ -1245,7 +1245,7 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
optimizer.zero_grad()
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
|
@ -1072,7 +1072,7 @@ def train_one_epoch(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||
optimizer.zero_grad()
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
|
Loading…
x
Reference in New Issue
Block a user