Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-12 11:32:19 +00:00)
Fix bugs introduced by previous commits
Along with reformatting to pass black lint:
- egs/libritts/ASR/zipformer/train.py
- egs/libritts/CODEC/encodec/encodec.py
- egs/ljspeech/TTS/vits/vits.py
- egs/wenetspeech4tts/TTS/valle/train.py
This commit is contained in:
parent 30ba83a7b2
commit b29ab59cce
@@ -1049,8 +1049,8 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.amp.autocast("cuda",
-                enabled=params.use_autocast, dtype=params.dtype
+            with torch.amp.autocast(
+                "cuda", enabled=params.use_autocast, dtype=params.dtype
             ):
                 loss, loss_info = compute_loss(
                     params=params,
@@ -1478,8 +1478,8 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.amp.autocast("cuda",
-                enabled=params.use_autocast, dtype=params.dtype
+            with torch.amp.autocast(
+                "cuda", enabled=params.use_autocast, dtype=params.dtype
             ):
                 loss, _ = compute_loss(
                     params=params,
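Both hunks re-wrap the same call for black, and together they show the calling convention of the device-agnostic autocast API: the device type ("cuda") is the first positional argument, with enabled and dtype passed as keywords. Below is a minimal, self-contained sketch of that pattern; the toy model, batch, and params namespace are illustrative stand-ins, not icefall code.

import torch
from types import SimpleNamespace

# Stand-in for the icefall params object referenced in the hunks above.
params = SimpleNamespace(use_autocast=True, dtype=torch.float16)

model = torch.nn.Linear(80, 4)
x = torch.randn(8, 80)

# Fall back to CPU (and bfloat16, which CPU autocast supports) when no
# GPU is present, so the sketch runs anywhere.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = params.dtype if device_type == "cuda" else torch.bfloat16

with torch.amp.autocast(
    device_type, enabled=params.use_autocast, dtype=amp_dtype
):
    # Eligible ops inside the context run in the reduced-precision dtype.
    loss = model(x).sum()
print(loss.dtype)  # e.g. torch.float16 on CUDA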
@@ -29,7 +29,7 @@ from loss import (
     WavReconstructionLoss,
 )
 from torch import nn
-from torch.cuda.amp import autocast
+from torch.amp import autocast


 class Encodec(nn.Module):
@@ -25,7 +25,7 @@ from loss import (
     KLDivergenceLoss,
     MelSpectrogramLoss,
 )
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from utils import get_segments

 AVAILABLE_GENERATERS = {
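The import change in these two hunks is the substantive fix: torch.cuda.amp.autocast is deprecated in recent PyTorch releases in favor of the device-agnostic torch.amp.autocast. The torch.amp version no longer implies CUDA, so every call site must name the device type, which is exactly what the train.py hunks above add. A hedged before/after sketch (the tensors are illustrative, not code from encodec.py or vits.py):

import torch
from torch.amp import autocast  # replaces: from torch.cuda.amp import autocast

x = torch.randn(4, 4)

# Old style (deprecated, warns on recent PyTorch):
#     with autocast():        # torch.cuda.amp.autocast implied CUDA
# New style: the device type is explicit.
with autocast("cuda", enabled=torch.cuda.is_available()):
    y = x @ x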
@@ -1103,7 +1103,9 @@ def run(rank, world_size, args):
         params=params,
     )

-    scaler = GradScaler("cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0)
+    scaler = GradScaler(
+        "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
+    )
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
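GradScaler follows the same migration: the torch.amp version takes the device type as its first argument, and here it is enabled only for fp16 training, with init_scale=1.0. A minimal end-to-end sketch of how the scaler cooperates with autocast and with the grad_scaler checkpoint entry loaded above, assuming a PyTorch new enough to ship torch.amp.GradScaler; the model, optimizer, and data are stand-ins, not code from this commit.

import torch
from torch.amp import GradScaler, autocast

use_fp16 = torch.cuda.is_available()
device = "cuda" if use_fp16 else "cpu"

model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler("cuda", enabled=use_fp16, init_scale=1.0)

x = torch.randn(16, 10, device=device)
with autocast(device, enabled=use_fp16, dtype=torch.float16):
    loss = model(x).pow(2).mean()

# Scale the loss so small fp16 gradients do not underflow; step() unscales
# before the optimizer update, and update() adapts the scale factor.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# The scaler state round-trips through checkpoints, mirroring the
# load_state_dict call in the hunk above.
checkpoint = {"grad_scaler": scaler.state_dict()}
scaler.load_state_dict(checkpoint["grad_scaler"])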