mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Fix bugs; first version that is running successfully.
This commit is contained in:
parent
c3a8727446
commit
7711fba867
@ -141,7 +141,7 @@ def get_params() -> AttributeDict:
|
||||
"start_epoch": 0,
|
||||
"num_epochs": 20,
|
||||
"num_valid_batches": 100,
|
||||
"symbols_per_batch": 10000,
|
||||
"symbols_per_batch": 5000,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
@ -417,13 +417,13 @@ def train_one_epoch(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
|
||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||
f"batch size: {batch_size}"
|
||||
)
|
||||
f"batch shape: {tuple(batch[0].shape)}")
|
||||
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/current_loss",
|
||||
loss_cpu / params.train_frames,
|
||||
loss_cpu / num_frames_cpu,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
@ -549,7 +549,7 @@ def run(rank, world_size, args):
|
||||
collate_fn=collate_fn)
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user