Update train.py

Masoud Azizi 2024-11-30 18:47:25 +03:30 committed by GitHub
parent a1ade8ecb7
commit b7b65205d1


@@ -30,7 +30,7 @@ import torch.nn as nn
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from tokenizer import Tokenizer
-from torch.cuda.amp import GradScaler, autocast
+from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
@@ -396,7 +396,7 @@ def train_one_epoch(
         loss_info["samples"] = batch_size
 
         try:
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 # forward discriminator
                 loss_d, stats_d = model(
                     text=tokens,
@@ -414,7 +414,7 @@ def train_one_epoch(
             scaler.scale(loss_d).backward()
             scaler.step(optimizer_d)
 
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 # forward generator
                 loss_g, stats_g = model(
                     text=tokens,
@@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom(
         )
         try:
             # for discriminator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 loss_d, stats_d = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom(
             optimizer_d.zero_grad()
             loss_d.backward()
             # for generator
-            with autocast(enabled=params.use_fp16):
+            with autocast('cuda',enabled=params.use_fp16):
                 loss_g, stats_g = model(
                     text=tokens,
                     text_lengths=tokens_lens,
@@ -838,7 +838,7 @@ def run(rank, world_size, args):
         params=params,
     )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler('cuda',enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
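
For reference, below is a minimal sketch of the device-aware torch.amp usage pattern this commit switches to. The model, optimizer, and tensors are placeholders rather than the actual VITS training code, and it assumes a recent PyTorch release (where torch.amp.GradScaler and torch.amp.autocast supersede the deprecated torch.cuda.amp variants) with a CUDA device available.

import torch
from torch.amp import GradScaler, autocast  # device-agnostic AMP API

# Placeholder model/optimizer standing in for the real generator/discriminator.
model = torch.nn.Linear(80, 80).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

use_fp16 = True
# GradScaler now takes the device type as its first argument.
scaler = GradScaler("cuda", enabled=use_fp16, init_scale=1.0)

x = torch.randn(16, 80, device="cuda")
y = torch.randn(16, 80, device="cuda")

optimizer.zero_grad()
# autocast likewise requires the device type explicitly.
with autocast("cuda", enabled=use_fp16):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()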