Update vits.py

FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
This commit is contained in:
Masoud Azizi 2024-11-30 18:55:13 +03:30 committed by GitHub
parent b7b65205d1
commit 046d257570
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -25,7 +25,7 @@ from loss import (
KLDivergenceLoss,
MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from torch.amp import autocast
from utils import get_segments
AVAILABLE_GENERATERS = {
@ -410,7 +410,7 @@ class VITS(nn.Module):
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
with autocast('cuda',enabled=False):
if not return_sample:
mel_loss = self.mel_loss(speech_hat_, speech_)
else:
@ -518,7 +518,7 @@ class VITS(nn.Module):
p = self.discriminator(speech_)
# calculate losses
with autocast(enabled=False):
with autocast('cuda',enabled=False):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss