mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Update vits.py
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
This commit is contained in:
parent
b7b65205d1
commit
046d257570
@ -25,7 +25,7 @@ from loss import (
|
||||
KLDivergenceLoss,
|
||||
MelSpectrogramLoss,
|
||||
)
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.amp import autocast
|
||||
from utils import get_segments
|
||||
|
||||
AVAILABLE_GENERATERS = {
|
||||
@ -410,7 +410,7 @@ class VITS(nn.Module):
|
||||
p = self.discriminator(speech_)
|
||||
|
||||
# calculate losses
|
||||
with autocast(enabled=False):
|
||||
with autocast('cuda',enabled=False):
|
||||
if not return_sample:
|
||||
mel_loss = self.mel_loss(speech_hat_, speech_)
|
||||
else:
|
||||
@ -518,7 +518,7 @@ class VITS(nn.Module):
|
||||
p = self.discriminator(speech_)
|
||||
|
||||
# calculate losses
|
||||
with autocast(enabled=False):
|
||||
with autocast('cuda',enabled=False):
|
||||
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
|
||||
loss = real_loss + fake_loss
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user