diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 19c174d39..20a111dd5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -28,9 +28,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
   --max-duration 300
-
-
-
 """
@@ -938,14 +935,15 @@ def scan_pessimistic_batches_for_oom(
             # warmup = 0.0 is so that the derivs for the pruned loss stay zero
             # (i.e. are not remembered by the decaying-average in adam), because
             # we want to avoid these params being subject to shrinkage in adam.
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup = 0.0
-            )
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup = 0.0
+                )
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
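
For reference, the idiom this hunk introduces is the standard torch.cuda.amp.autocast pattern: the forward pass that produces the loss runs under autocast(enabled=params.use_fp16), so the OOM probe in scan_pessimistic_batches_for_oom allocates memory the same way the real mixed-precision training step will. The sketch below is illustrative only; the toy model, optimizer, and tensors are stand-ins and are not icefall's compute_loss()/params, nor part of the patch.

# Minimal sketch of the autocast idiom applied in the hunk above.
# The toy model, optimizer, and tensors are illustrative stand-ins,
# not icefall's compute_loss()/params.
import torch

def probe_batch_for_oom(model, optimizer, features, targets, use_fp16):
    # Run the forward pass under autocast so this memory probe matches
    # the allocation pattern of the actual mixed-precision training step.
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = torch.nn.functional.mse_loss(model(features), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(80, 10).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    features = torch.randn(4, 80, device=device)
    targets = torch.randn(4, 10, device=device)
    probe_batch_for_oom(model, optimizer, features, targets,
                        use_fp16=torch.cuda.is_available())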