diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 8db8bc920..8c06eb4f9 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -516,6 +516,8 @@ def train_one_epoch(
     """
     model.train()
 
+    libri_tot_loss = MetricsTracker()
+    giga_tot_loss = MetricsTracker()
     tot_loss = MetricsTracker()
 
     # index 0: for LibriSpeech
@@ -542,6 +544,8 @@
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
+        libri = is_libri(batch["supervisions"]["cut"][0])
+
         loss, loss_info = compute_loss(
             params=params,
             model=model,
@@ -551,6 +555,16 @@
         )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+        if libri:
+            libri_tot_loss = (
+                libri_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "libri"  # for logging only
+        else:
+            giga_tot_loss = (
+                giga_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "giga"
 
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
@@ -563,19 +577,29 @@
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
-                f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], "
+                f"libri_tot_loss[{libri_tot_loss}], "
+                f"giga_tot_loss[{giga_tot_loss}], "
+                f"batch size: {batch_size}"
             )
 
         if batch_idx % params.log_interval == 0:
-
            if tb_writer is not None:
                loss_info.write_summary(
-                    tb_writer, "train/current_", params.batch_idx_train
+                    tb_writer,
+                    f"train/current_{prefix}_",
+                    params.batch_idx_train,
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
+                libri_tot_loss.write_summary(
+                    tb_writer, "train/libri_tot_", params.batch_idx_train
+                )
+                giga_tot_loss.write_summary(
+                    tb_writer, "train/giga_tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -738,7 +762,9 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
 
-    for dl in [train_dl, giga_train_dl]:
+    # It's time consuming to include `giga_train_dl` here
+    # for dl in [train_dl, giga_train_dl]:
+    for dl in [train_dl]:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=dl,
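
For context on the hunks above: each batch is drawn entirely from one corpus, so checking `is_libri` on the first cut of the batch (a helper presumably defined elsewhere in train.py) is enough to route the batch's stats to the matching tracker. The per-dataset totals then reuse the same exponentially decayed running sum as `tot_loss`: the previous total is scaled by `1 - 1/reset_interval` before the new batch's stats are added, so each tracker reflects roughly the last `reset_interval` batches of its corpus. Below is a minimal standalone sketch of that update; the `Tracker` class is a simplified stand-in for icefall's `MetricsTracker`, and the `reset_interval` value is made up for the demo.

from collections import defaultdict


class Tracker(defaultdict):
    # Simplified stand-in for icefall's MetricsTracker: a dict of summed
    # metric values supporting scaling by a scalar and element-wise
    # addition, which is all the decayed-total update needs.
    def __init__(self):
        super().__init__(float)

    def __mul__(self, alpha: float) -> "Tracker":
        out = Tracker()
        for k, v in self.items():
            out[k] = v * alpha
        return out

    def __add__(self, other: "Tracker") -> "Tracker":
        out = Tracker()
        for k, v in self.items():
            out[k] = v
        for k, v in other.items():
            out[k] += v  # defaultdict supplies 0.0 for unseen keys
        return out


reset_interval = 200  # made-up value; train.py reads it from params
libri_tot_loss = Tracker()
giga_tot_loss = Tracker()

# Simulate a few batches, alternating between the two corpora.
for libri, loss in [(True, 10.0), (False, 30.0), (True, 9.5)]:
    loss_info = Tracker()
    loss_info["loss"] = loss
    # Decay the matching total, then fold in this batch's stats,
    # mirroring the update the patch applies to libri_tot_loss
    # and giga_tot_loss.
    if libri:
        libri_tot_loss = libri_tot_loss * (1 - 1 / reset_interval) + loss_info
    else:
        giga_tot_loss = giga_tot_loss * (1 - 1 / reset_interval) + loss_info

print(libri_tot_loss["loss"], giga_tot_loss["loss"])
# -> 19.45 (= 10.0 * 0.995 + 9.5) and 30.0

Scaling before adding keeps the totals bounded and makes stale batches fade geometrically, so no explicit reset of the trackers is ever needed.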