Mirror of https://github.com/k2-fsa/icefall.git
Add logging about memory used.
parent 6a6df19bde
commit 78f3cba58c
@@ -558,10 +558,12 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     x_sign = x.sign()
     over_limit = (x.abs() - limit) > 0
     # The following is a memory efficient way to penalize the absolute values of
-    # x that's over the limit. the numerical value of aux_loss as computed here will actually be
-    # larger than it should be, but it has the same derivative as
-    # penalty * (x.abs() - limit).relu()
-    # which is what we really want to penalize
+    # x that's over the limit.  (The memory efficiency comes when you think
+    # about which items torch needs to cache for the autograd, and which ones it
+    # can throw away).  The numerical value of aux_loss as computed here will
+    # actually be larger than it should be, by limit * over_limit.sum(), but it
+    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+    # limit).relu().
     aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux_loss, but it's as if we had done
     # sum() due to how with_loss() works.
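The updated comment claims the memory-efficient form has the same derivative as penalty * (x.abs() - limit).relu(). Below is a minimal standalone sketch (not part of the commit; it assumes only PyTorch and made-up values for x, limit, and penalty) that checks the two gradients agree:

import torch

limit, penalty = 1.0, 10.0
x = torch.tensor([-2.0, -0.5, 0.3, 1.5], requires_grad=True)

# Memory-efficient form from the diff: the multiplier is a small int8 tensor
# that autograd does not differentiate through, so roughly speaking it is the
# only thing that needs to be cached for backward, not extra float intermediates.
x_sign = x.sign()
over_limit = (x.abs() - limit) > 0
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
aux_loss.sum().backward()
grad_efficient = x.grad.clone()

# Reference form named in the comment: penalty * (x.abs() - limit).relu()
x.grad = None
(penalty * (x.abs() - limit).relu()).sum().backward()
grad_reference = x.grad.clone()

# The values of the two losses differ (by penalty * limit * over_limit.sum()),
# but the gradients with respect to x are identical.
assert torch.allclose(grad_efficient, grad_reference)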
@@ -901,6 +901,7 @@ def train_one_epoch(
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
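For context on the added logging line, here is a standalone sketch (not part of the commit) of how torch.cuda.max_memory_allocated() reports the peak bytes held by the CUDA caching allocator; the tensor sizes and the per-epoch reset are illustrative assumptions, not something the commit does:

import logging
import torch

logging.basicConfig(level=logging.INFO)

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # do some work so the peak counter is non-zero
    # Same integer-division formatting as the diff: bytes // 1000000 -> MB.
    logging.info(
        f"Maximum memory allocated so far is "
        f"{torch.cuda.max_memory_allocated()//1000000}MB"
    )
    # Optionally reset the peak so each epoch's maximum is measured separately.
    torch.cuda.reset_peak_memory_stats()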