Mirror of https://github.com/k2-fsa/icefall.git
Commit 5cf74ab506: Merge 046d25757087f155249102eab936e819f2d004e5 into 32b7a449e7ed87efdf0a49f74b01c846e831c8a3
@@ -30,7 +30,7 @@ import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from tokenizer import Tokenizer
-from torch.cuda.amp import GradScaler, autocast
+from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
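
The hunks in this commit migrate from the deprecated CUDA-only torch.cuda.amp namespace to the device-agnostic torch.amp one, where the device type is passed explicitly. A minimal sketch of the correspondence, assuming PyTorch 2.3 or newer (torch.amp.GradScaler does not exist in older releases); the variables below are illustrative, not the recipe's:

import torch
from torch.amp import GradScaler, autocast  # replaces: from torch.cuda.amp import GradScaler, autocast

use_fp16 = True

# The device type is now an explicit first argument; 'cuda' reproduces the old
# torch.cuda.amp behaviour (scaling is disabled with a warning if no GPU is present).
scaler = GradScaler('cuda', enabled=use_fp16, init_scale=1.0)

with autocast('cuda', enabled=use_fp16):
    pass  # mixed-precision forward pass goes here
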
@@ -396,7 +396,7 @@ def train_one_epoch(
         loss_info["samples"] = batch_size

         try:
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, stats_d = model(
                     text=tokens,
@@ -414,7 +414,7 @@ def train_one_epoch(
             scaler.scale(loss_d).backward()
             scaler.step(optimizer_d)

-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 # forward generator
                 loss_g, stats_g = model(
                     text=tokens,
@@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom(
         )
         try:
             # for discriminator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 loss_d, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 loss_g, stats_g = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -838,7 +838,7 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler('cuda',enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
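
Taken together, the training-script hunks above change only the call signatures, not the update logic. A condensed sketch of that discriminator/generator AMP step under the new torch.amp API; the model, optimizers, and batch here are stand-ins rather than the recipe's actual objects, and the forward_generator flag is assumed from the "forward discriminator" / "forward generator" comments in the diff:

import torch
from torch.amp import GradScaler, autocast

def gan_amp_step(model, optimizer_d, optimizer_g, batch, scaler, use_fp16=True):
    # Discriminator update: forward under autocast, backward and step through the scaler.
    with autocast('cuda', enabled=use_fp16):
        loss_d, stats_d = model(**batch, forward_generator=False)
    optimizer_d.zero_grad()
    scaler.scale(loss_d).backward()
    scaler.step(optimizer_d)

    # Generator update: same pattern with the other optimizer.
    with autocast('cuda', enabled=use_fp16):
        loss_g, stats_g = model(**batch, forward_generator=True)
    optimizer_g.zero_grad()
    scaler.scale(loss_g).backward()
    scaler.step(optimizer_g)

    # Call update() once per iteration, after every optimizer that used the scaler has stepped.
    scaler.update()
    return stats_d, stats_g
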
@@ -25,7 +25,7 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from utils import get_segments

 AVAILABLE_GENERATERS = {
@@ -410,7 +410,7 @@ class VITS(nn.Module):
             p = self.discriminator(speech_)

         # calculate losses
-        with autocast(enabled=False):
+        with autocast('cuda',enabled=False):
             if not return_sample:
                 mel_loss = self.mel_loss(speech_hat_, speech_)
             else:
@@ -518,7 +518,7 @@ class VITS(nn.Module):
         p = self.discriminator(speech_)

         # calculate losses
-        with autocast(enabled=False):
+        with autocast('cuda',enabled=False):
             real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
             loss = real_loss + fake_loss

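
In the model-module hunks (what appears to be the VITS class), autocast is entered with enabled=False so the adversarial and mel losses are computed in full float32 even when the surrounding forward pass runs under mixed precision. A minimal sketch of that nesting with the new API; the tensors and loss below are placeholders:

import torch
import torch.nn.functional as F
from torch.amp import autocast

x = torch.randn(2, 80, 100)
y = torch.randn(2, 80, 100)
if torch.cuda.is_available():
    x, y = x.cuda(), y.cuda()

with autocast('cuda', enabled=True):
    pred = x * 2.0  # stand-in for a generator forward pass (may produce fp16 tensors)
    # Locally disable autocast so the loss itself is computed in float32;
    # casting the inputs with .float() is the documented pattern for such regions.
    with autocast('cuda', enabled=False):
        loss = F.l1_loss(pred.float(), y.float())
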