diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py
index 85d671ecc..72bc654ea 100755
--- a/egs/librispeech/ASR/conformer_lm/train.py
+++ b/egs/librispeech/ASR/conformer_lm/train.py
@@ -141,7 +141,7 @@ def get_params() -> AttributeDict:
             "start_epoch": 0,
             "num_epochs": 20,
             "num_valid_batches": 100,
-            "symbols_per_batch": 10000,
+            "symbols_per_batch": 5000,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -417,13 +417,13 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                 f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
                 f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
-            )
+                f"batch shape: {tuple(batch[0].shape)}")
+
 
             if tb_writer is not None:
                 tb_writer.add_scalar(
                     "train/current_loss",
-                    loss_cpu / params.train_frames,
+                    loss_cpu / num_frames_cpu,
                     params.batch_idx_train,
                 )
                 tb_writer.add_scalar(
@@ -549,7 +549,7 @@ def run(rank, world_size, args):
         collate_fn=collate_fn)
 
     for epoch in range(params.start_epoch, params.num_epochs):
-        train_dl.sampler.set_epoch(epoch)
+        train_sampler.set_epoch(epoch)
 
         cur_lr = optimizer._rate
         if tb_writer is not None: