mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00
Update vits.py
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
This commit is contained in:
parent
b7b65205d1
commit
046d257570
@ -25,7 +25,7 @@ from loss import (
|
|||||||
KLDivergenceLoss,
|
KLDivergenceLoss,
|
||||||
MelSpectrogramLoss,
|
MelSpectrogramLoss,
|
||||||
)
|
)
|
||||||
from torch.cuda.amp import autocast
|
from torch.amp import autocast
|
||||||
from utils import get_segments
|
from utils import get_segments
|
||||||
|
|
||||||
AVAILABLE_GENERATERS = {
|
AVAILABLE_GENERATERS = {
|
||||||
@ -410,7 +410,7 @@ class VITS(nn.Module):
|
|||||||
p = self.discriminator(speech_)
|
p = self.discriminator(speech_)
|
||||||
|
|
||||||
# calculate losses
|
# calculate losses
|
||||||
with autocast(enabled=False):
|
with autocast('cuda',enabled=False):
|
||||||
if not return_sample:
|
if not return_sample:
|
||||||
mel_loss = self.mel_loss(speech_hat_, speech_)
|
mel_loss = self.mel_loss(speech_hat_, speech_)
|
||||||
else:
|
else:
|
||||||
@ -518,7 +518,7 @@ class VITS(nn.Module):
|
|||||||
p = self.discriminator(speech_)
|
p = self.discriminator(speech_)
|
||||||
|
|
||||||
# calculate losses
|
# calculate losses
|
||||||
with autocast(enabled=False):
|
with autocast('cuda',enabled=False):
|
||||||
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
|
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
|
||||||
loss = real_loss + fake_loss
|
loss = real_loss + fake_loss
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user