diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 698db781c..a2200c04b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -558,10 +558,12 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     x_sign = x.sign()
     over_limit = (x.abs() - limit) > 0
     # The following is a memory efficient way to penalize the absolute values of
-    # x that's over the limit. the numerical value of aux_loss as computed here will actually be
-    # larger than it should be, but it has the same derivative as
-    # penalty * (x.abs() - limit).relu()
-    # which is what we really want to penalize
+    # x that's over the limit. (The memory efficiency comes when you think
+    # about which items torch needs to cache for the autograd, and which ones it
+    # can throw away). The numerical value of aux_loss as computed here will
+    # actually be larger than it should be, by limit * over_limit.sum(), but it
+    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+    # limit).relu().
     aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux)_loss, but it's as if we had done
     # sum() due to how with_loss() works.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b00074051..66c25831f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -901,6 +901,7 @@ def train_one_epoch(
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
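
A minimal standalone sketch (not part of the diff above) checking the claim made in the updated comment: the memory-efficient aux_loss has the same gradient with respect to x as the "real" penalty * (x.abs() - limit).relu(), even though its numerical value is larger. The tensor size and the limit/penalty constants below are arbitrary, chosen only for illustration.

import torch

limit, penalty = 1.0, 0.1

x1 = (torch.randn(1000) * 3.0).requires_grad_(True)
x2 = x1.detach().clone().requires_grad_(True)

# Memory-efficient form, as in penalize_abs_values_gt: the int8 cast detaches
# the sign/mask factor, so autograd only needs to track the final multiply by x.
x_sign = x1.sign()
over_limit = (x1.abs() - limit) > 0
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x1)
aux_loss.sum().backward()

# Reference form that we really want to penalize.
ref_loss = penalty * (x2.abs() - limit).relu()
ref_loss.sum().backward()

# Both give penalty * sign(x) * [|x| > limit] as the gradient.
assert torch.allclose(x1.grad, x2.grad)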