Merge 046d25757087f155249102eab936e819f2d004e5 into 32b7a449e7ed87efdf0a49f74b01c846e831c8a3

Masoud Azizi 2024-12-09 12:32:00 +00:00 committed by GitHub
commit 5cf74ab506
2 changed files with 9 additions and 9 deletions

View File

@@ -30,7 +30,7 @@ import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from tokenizer import Tokenizer
-from torch.cuda.amp import GradScaler, autocast
+from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
@@ -396,7 +396,7 @@ def train_one_epoch(
         loss_info["samples"] = batch_size

         try:
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, stats_d = model(
                     text=tokens,
@@ -414,7 +414,7 @@ def train_one_epoch(
             scaler.scale(loss_d).backward()
             scaler.step(optimizer_d)

-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 # forward generator
                 loss_g, stats_g = model(
                     text=tokens,
@@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom(
         )
         try:
             # for discriminator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 loss_d, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 loss_g, stats_g = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -838,7 +838,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler('cuda',enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
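
Note: the change in this file swaps the deprecated torch.cuda.amp entry points for their torch.amp counterparts, which take the device type as an explicit first argument. Below is a minimal sketch of the migrated pattern, not part of this commit; it assumes PyTorch >= 2.3 (where torch.amp.GradScaler accepts a device argument) and a CUDA device, and the toy model, optimizer, and use_fp16 flag are placeholders.

import torch
from torch.amp import GradScaler, autocast

use_fp16 = True  # placeholder for params.use_fp16
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# GradScaler now takes the device type ("cuda") as its first argument.
scaler = GradScaler("cuda", enabled=use_fp16, init_scale=1.0)

x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, 1, device="cuda")

optimizer.zero_grad()
# autocast likewise takes the device type explicitly.
with autocast("cuda", enabled=use_fp16):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

With use_fp16=False both autocast and GradScaler act as no-ops, so the change is inert for fp32 runs.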

View File

@@ -25,7 +25,7 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from utils import get_segments

 AVAILABLE_GENERATERS = {
@@ -410,7 +410,7 @@ class VITS(nn.Module):
             p = self.discriminator(speech_)

         # calculate losses
-        with autocast(enabled=False):
+        with autocast('cuda',enabled=False):
             if not return_sample:
                 mel_loss = self.mel_loss(speech_hat_, speech_)
             else:
@@ -518,7 +518,7 @@ class VITS(nn.Module):
         p = self.discriminator(speech_)

         # calculate losses
-        with autocast(enabled=False):
+        with autocast('cuda',enabled=False):
             real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
             loss = real_loss + fake_loss
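
Note: the autocast('cuda', enabled=False) blocks in this file keep the loss computation in full precision even when the surrounding forward pass runs under mixed precision. Below is a minimal sketch of that nesting, not part of this commit; it assumes PyTorch >= 2.3 and a CUDA device, and the tensors and the squared-error "loss" are placeholders, not the VITS losses.

import torch
from torch.amp import autocast

a = torch.randn(4, 4, device="cuda")
b = torch.randn(4, 4, device="cuda")

with autocast("cuda", enabled=True):
    prod = a @ b  # matmul runs in float16 under autocast
    with autocast("cuda", enabled=False):
        # autocast is disabled here; cast inputs back to float32 so the
        # loss is computed in full precision
        loss = (prod.float() - b).pow(2).mean()

print(prod.dtype, loss.dtype)  # torch.float16 torch.float32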