mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Update train.py
This commit is contained in:
parent
1b854e5c44
commit
ef0b6df8f8
@ -677,7 +677,7 @@ def train_one_epoch(
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
warmup=(params.batch_idx_train / params.model_warm_step)
|
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||||
)
|
)
|
||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
@ -727,7 +727,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_info.write_summary(
|
loss_info.write_summary(
|
||||||
@ -959,7 +959,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
warmup=0.0
|
warmup=0.0,
|
||||||
)
|
)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user