Update vits.py

FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
This commit is contained in:
Masoud Azizi 2024-11-30 18:48:18 +03:30 committed by GitHub
parent a1ade8ecb7
commit 6c225a92b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -25,7 +25,7 @@ from loss import (
KLDivergenceLoss, KLDivergenceLoss,
MelSpectrogramLoss, MelSpectrogramLoss,
) )
from torch.cuda.amp import autocast from torch.amp import autocast
from utils import get_segments from utils import get_segments
AVAILABLE_GENERATERS = { AVAILABLE_GENERATERS = {
@ -410,7 +410,7 @@ class VITS(nn.Module):
p = self.discriminator(speech_) p = self.discriminator(speech_)
# calculate losses # calculate losses
with autocast(enabled=False): with autocast('cuda',enabled=False):
if not return_sample: if not return_sample:
mel_loss = self.mel_loss(speech_hat_, speech_) mel_loss = self.mel_loss(speech_hat_, speech_)
else: else:
@ -518,7 +518,7 @@ class VITS(nn.Module):
p = self.discriminator(speech_) p = self.discriminator(speech_)
# calculate losses # calculate losses
with autocast(enabled=False): with autocast('cuda',enabled=False):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss loss = real_loss + fake_loss