From b7b65205d15577e28e21111b78aa68219f1ab7d3 Mon Sep 17 00:00:00 2001
From: Masoud Azizi
Date: Sat, 30 Nov 2024 18:47:25 +0330
Subject: [PATCH] Update train.py

---
 egs/ljspeech/TTS/vits/train.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
index 184ae79af..6f69ca9ad 100755
--- a/egs/ljspeech/TTS/vits/train.py
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -30,7 +30,7 @@ import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from tokenizer import Tokenizer
-from torch.cuda.amp import GradScaler, autocast
+from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
@@ -396,7 +396,7 @@ def train_one_epoch(
         loss_info["samples"] = batch_size
 
         try:
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, stats_d = model(
                     text=tokens,
@@ -414,7 +414,7 @@ def train_one_epoch(
             scaler.scale(loss_d).backward()
             scaler.step(optimizer_d)
 
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 # forward generator
                 loss_g, stats_g = model(
                     text=tokens,
@@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom(
         )
         try:
             # for discriminator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 loss_d, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda', enabled=params.use_fp16):
                 loss_g, stats_g = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -838,7 +838,7 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler('cuda', enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
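
Note (not part of the patch): torch.cuda.amp.autocast(enabled=...) and
torch.cuda.amp.GradScaler(...) are deprecated in recent PyTorch releases in
favor of the device-agnostic torch.amp equivalents, which take the device type
as their first argument; that is exactly the change this patch makes. Below is
a minimal, self-contained sketch of the migrated API. The toy model, tensors,
and the use_fp16 flag are placeholders standing in for params.use_fp16 and the
real training loop in train.py, and the sketch assumes a PyTorch version that
provides torch.amp.GradScaler (roughly 2.3+).

import torch
from torch.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"  # stands in for params.use_fp16

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# torch.cuda.amp.GradScaler(...) becomes torch.amp.GradScaler("cuda", ...)
scaler = GradScaler("cuda", enabled=use_fp16, init_scale=1.0)

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

# torch.cuda.amp.autocast(enabled=...) becomes torch.amp.autocast("cuda", enabled=...)
with autocast("cuda", enabled=use_fp16):
    loss = torch.nn.functional.mse_loss(model(x), y)

# The scaler usage itself is unchanged by the migration.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

In train.py the same two-argument forms are dropped into the existing
discriminator and generator passes, so no other training logic changes.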