Fix bugs; first version that runs successfully.

Daniel Povey 2021-08-23 22:40:23 +08:00
parent c3a8727446
commit 7711fba867


@@ -141,7 +141,7 @@ def get_params() -> AttributeDict:
             "start_epoch": 0,
             "num_epochs": 20,
             "num_valid_batches": 100,
-            "symbols_per_batch": 10000,
+            "symbols_per_batch": 5000,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -417,13 +417,13 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                 f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
                 f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
-            )
+                f"batch shape: {tuple(batch[0].shape)}")
             if tb_writer is not None:
                 tb_writer.add_scalar(
                     "train/current_loss",
-                    loss_cpu / params.train_frames,
+                    loss_cpu / num_frames_cpu,
                     params.batch_idx_train,
                 )
                 tb_writer.add_scalar(
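
Two fixes land in this hunk: the progress message now prints the shape of the batch rather than a separate batch_size variable, and the TensorBoard scalar is divided by the frame count of the current batch (num_frames_cpu) instead of the running total params.train_frames, so the logged curve is a per-batch average. A minimal sketch of that logging pattern, assuming a standard torch.utils.tensorboard SummaryWriter and with illustrative variable names rather than the script's own:

from torch.utils.tensorboard import SummaryWriter


def log_batch_loss(
    writer: SummaryWriter, loss: float, num_frames: int, step: int
) -> None:
    # Divide by the frames in this batch, not by a running frame total,
    # so the logged value is the average loss of the current batch.
    writer.add_scalar("train/current_loss", loss / num_frames, step)


writer = SummaryWriter(log_dir="tb")  # hypothetical log directory
log_batch_loss(writer, loss=123.4, num_frames=300, step=0)
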
@@ -549,7 +549,7 @@ def run(rank, world_size, args):
         collate_fn=collate_fn)
     for epoch in range(params.start_epoch, params.num_epochs):
-        train_dl.sampler.set_epoch(epoch)
+        train_sampler.set_epoch(epoch)
         cur_lr = optimizer._rate
         if tb_writer is not None:
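
The last hunk calls set_epoch() on the sampler object constructed in run() instead of reaching for it through train_dl.sampler, which may not point at the right object when a custom sampler or collate pipeline is in use. The sketch below shows the standard PyTorch pattern this follows; the dataset and loader are placeholders, not the repository's own objects.

# Minimal sketch: keep a handle to the sampler and call set_epoch() at the
# start of every epoch so shuffling differs between epochs in distributed
# training. num_replicas/rank are passed explicitly so the example runs
# without initializing torch.distributed.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
train_sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
train_dl = DataLoader(dataset, batch_size=8, sampler=train_sampler)

for epoch in range(3):
    # Calling set_epoch on the sampler we constructed is always safe,
    # whereas train_dl.sampler may not expose it in every setup, which is
    # presumably why the code now calls it on train_sampler directly.
    train_sampler.set_epoch(epoch)
    for batch in train_dl:
        pass  # training step would go here
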