From 34c380cd8bd5c69bf12c09c262f509c5bc761ca1 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 5 Apr 2022 07:50:52 +0800 Subject: [PATCH] For older pytorch version --- egs/librispeech/ASR/pruned_transducer_stateless/model.py | 4 ++-- egs/librispeech/ASR/pruned_transducer_stateless/train.py | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 810e7ace5..468a126fb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -128,7 +128,7 @@ class Transducer(nn.Module): boundary[:, 2] = y_lens boundary[:, 3] = x_lens - with torch.autocast(device_type=x.device.type, enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=decoder_out.float(), am=encoder_out.float(), @@ -158,7 +158,7 @@ class Transducer(nn.Module): # logits : [B, T, prune_range, C] logits = self.joiner(am_pruned, lm_pruned) - with torch.autocast(device_type=x.device.type, enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index dad2b2b18..0b2025070 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -647,9 +647,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.autocast( - device_type=model.device.type, enabled=params.use_fp16 - ): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -920,9 +918,7 @@ def scan_pessimistic_batches_for_oom( batch = train_dl.dataset[cuts] try: optimizer.zero_grad() - with torch.autocast( - device_type=model.device.type, enabled=params.use_fp16 - ): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model,