From b29ab59cceeea8e415a31d056368d4f14dad1897 Mon Sep 17 00:00:00 2001
From: Li Peng
Date: Tue, 3 Dec 2024 14:37:14 +0800
Subject: [PATCH] Fix bugs introduced by previous commits

Along with reformatting to pass black lint.

- egs/libritts/ASR/zipformer/train.py
- egs/libritts/CODEC/encodec/encodec.py
- egs/ljspeech/TTS/vits/vits.py
- egs/wenetspeech4tts/TTS/valle/train.py
---
 egs/libritts/ASR/zipformer/train.py    | 8 ++++----
 egs/libritts/CODEC/encodec/encodec.py  | 2 +-
 egs/ljspeech/TTS/vits/vits.py          | 2 +-
 egs/wenetspeech4tts/TTS/valle/train.py | 4 +++-
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py
index 96c45385f..78e3330bd 100755
--- a/egs/libritts/ASR/zipformer/train.py
+++ b/egs/libritts/ASR/zipformer/train.py
@@ -1049,8 +1049,8 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         try:
-            with torch.amp.autocast("cuda",
-                enabled=params.use_autocast, dtype=params.dtype
+            with torch.amp.autocast(
+                "cuda", enabled=params.use_autocast, dtype=params.dtype
             ):
                 loss, loss_info = compute_loss(
                     params=params,
@@ -1478,8 +1478,8 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.amp.autocast("cuda",
-                enabled=params.use_autocast, dtype=params.dtype
+            with torch.amp.autocast(
+                "cuda", enabled=params.use_autocast, dtype=params.dtype
             ):
                 loss, _ = compute_loss(
                     params=params,
diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py
index d0bf234ae..31fc4f126 100644
--- a/egs/libritts/CODEC/encodec/encodec.py
+++ b/egs/libritts/CODEC/encodec/encodec.py
@@ -29,7 +29,7 @@ from loss import (
     WavReconstructionLoss,
 )
 from torch import nn
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 
 class Encodec(nn.Module):
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
index 1c0f252dc..6fd6d219b 100644
--- a/egs/ljspeech/TTS/vits/vits.py
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -25,7 +25,7 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from utils import get_segments
 
 AVAILABLE_GENERATERS = {
diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py
index 4167f356b..5662a6bb0 100755
--- a/egs/wenetspeech4tts/TTS/valle/train.py
+++ b/egs/wenetspeech4tts/TTS/valle/train.py
@@ -1103,7 +1103,9 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler("cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
+    scaler = GradScaler(
+        "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
+    )
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
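
Note for reviewers: the import changes above migrate from the deprecated
torch.cuda.amp entry points to the device-agnostic torch.amp API, whose first
positional argument is the device type. Below is a minimal sketch of the
new-style usage, assuming a CUDA-capable build of PyTorch >= 2.3 (where
torch.amp.GradScaler accepts a device string); the toy model, optimizer, and
tensors are placeholders for illustration, not code from these recipes.

    import torch
    from torch.amp import GradScaler, autocast

    # Toy model and optimizer, purely for illustration.
    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # New-style API: the device type ("cuda") is the first positional
    # argument, matching the GradScaler(...) call this patch reformats in
    # egs/wenetspeech4tts/TTS/valle/train.py.
    scaler = GradScaler("cuda", enabled=True, init_scale=1.0)

    x = torch.randn(8, 10, device="cuda")
    y = torch.randn(8, 1, device="cuda")

    # autocast("cuda", ...) takes the device type first; the removed
    # torch.cuda.amp.autocast(...) did not accept a device argument.
    with autocast("cuda", enabled=True, dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

With the old torch.cuda.amp import left in place, a call site written as
autocast("cuda", ...) would pass the device string where
torch.cuda.amp.autocast expects its enabled flag, which appears to be the
kind of bug the import changes in encodec.py and vits.py address.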