Fix bugs introduced by previous commits

Along with reformatting to pass black lint.

- egs/libritts/ASR/zipformer/train.py
- egs/libritts/CODEC/encodec/encodec.py
- egs/ljspeech/TTS/vits/vits.py
- egs/wenetspeech4tts/TTS/valle/train.py
This commit is contained in:
Li Peng 2024-12-03 14:37:14 +08:00
parent 30ba83a7b2
commit b29ab59cce
4 changed files with 9 additions and 7 deletions

View File

@ -1049,8 +1049,8 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
with torch.amp.autocast("cuda",
enabled=params.use_autocast, dtype=params.dtype
with torch.amp.autocast(
"cuda", enabled=params.use_autocast, dtype=params.dtype
):
loss, loss_info = compute_loss(
params=params,
@ -1478,8 +1478,8 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.amp.autocast("cuda",
enabled=params.use_autocast, dtype=params.dtype
with torch.amp.autocast(
"cuda", enabled=params.use_autocast, dtype=params.dtype
):
loss, _ = compute_loss(
params=params,

View File

@ -29,7 +29,7 @@ from loss import (
WavReconstructionLoss,
)
from torch import nn
from torch.cuda.amp import autocast
from torch.amp import autocast
class Encodec(nn.Module):

View File

@ -25,7 +25,7 @@ from loss import (
KLDivergenceLoss,
MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from torch.amp import autocast
from utils import get_segments
AVAILABLE_GENERATERS = {

View File

@ -1103,7 +1103,9 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler("cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
scaler = GradScaler(
"cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])