mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
fix error in accum_grad (#1693)
This commit is contained in:
parent
2e13298717
commit
11151415f3
@ -948,7 +948,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
@ -948,7 +948,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
@ -774,7 +774,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
@ -774,7 +774,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
@ -1245,7 +1245,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
@ -1072,7 +1072,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
Loading…
x
Reference in New Issue
Block a user