Update train.py: migrate from the deprecated torch.cuda.amp to the torch.amp API

Masoud Azizi 2024-11-30 18:47:25 +03:30 committed by GitHub
parent a1ade8ecb7
commit b7b65205d1


@@ -30,7 +30,7 @@ import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from tokenizer import Tokenizer
-from torch.cuda.amp import GradScaler, autocast
+from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
@@ -396,7 +396,7 @@ def train_one_epoch(
         loss_info["samples"] = batch_size
 
         try:
-            with autocast(enabled=params.use_fp16):
+            with autocast("cuda", enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, stats_d = model(
                     text=tokens,
@@ -414,7 +414,7 @@ def train_one_epoch(
             scaler.scale(loss_d).backward()
             scaler.step(optimizer_d)
 
-            with autocast(enabled=params.use_fp16):
+            with autocast("cuda", enabled=params.use_fp16):
                 # forward generator
                 loss_g, stats_g = model(
                     text=tokens,
@@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom(
         )
         try:
             # for discriminator
-            with autocast(enabled=params.use_fp16):
+            with autocast("cuda", enabled=params.use_fp16):
                 loss_d, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
-            with autocast(enabled=params.use_fp16):
+            with autocast("cuda", enabled=params.use_fp16):
                 loss_g, stats_g = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -838,7 +838,7 @@ def run(rank, world_size, args):
         params=params,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
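
For reference, the pattern this diff adopts is PyTorch's unified torch.amp API: as of PyTorch 2.4, torch.cuda.amp.autocast and torch.cuda.amp.GradScaler emit FutureWarning deprecation notices, and both replacements take the device type as their first positional argument. Below is a minimal, self-contained sketch of that usage; the linear model, random data, and use_fp16 flag are placeholders for illustration, not code from train.py.

# Minimal sketch of the torch.amp pattern used in this patch
# (assumes PyTorch >= 2.4; the model and data are placeholders).
import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"  # fp16 autocast targets CUDA here

model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Both classes now take the device type as the first positional argument.
scaler = GradScaler(device, enabled=use_fp16, init_scale=1.0)

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

with autocast(device, enabled=use_fp16):
    # ops inside the context run in mixed precision when enabled
    loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then steps the optimizer
scaler.update()                # adjusts the scale factor for the next step

Passing the device type explicitly is what the one-line changes above accomplish at each autocast site and at GradScaler construction, leaving the surrounding training logic untouched.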