fixed loss functions & scaling factors

This commit is contained in:
JinZr 2024-10-07 01:03:26 +08:00
parent 58f6562824
commit 01cc307664
3 changed files with 80 additions and 42 deletions

View File

@ -142,10 +142,10 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x.new_ones(x.size()) - x).mean()
return F.relu(torch.ones_like(x) - x).mean()
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x.new_ones(x.size()) + x).mean()
return F.relu(torch.ones_like(x) + x).mean()
class FeatureLoss(torch.nn.Module):
@ -200,7 +200,7 @@ class FeatureLoss(torch.nn.Module):
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += (
(feat_hat_ - feat_).abs() / (feat_.abs().mean())
F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
).mean()
if self.average_by_layers:
feat_match_loss_ /= j + 1
@ -272,9 +272,16 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
mel_hat = wav_to_spec(x_hat.squeeze(1))
mel = wav_to_spec(x.squeeze(1))
mel_loss += F.l1_loss(
mel_hat, mel, reduce=True, reduction="mean"
) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
mel_loss += (
F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+ (
(
(torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
** 2
).mean(dim=-2)
** 0.5
).mean()
)
# mel_hat = self.wav_to_spec(x_hat.squeeze(1))
# mel = self.wav_to_spec(x.squeeze(1))
@ -307,7 +314,7 @@ class WavReconstructionLoss(torch.nn.Module):
Tensor: Wav loss value.
"""
wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
wav_loss = F.l1_loss(x, x_hat)
return wav_loss

View File

@ -4,16 +4,40 @@ from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler
# It will be replaced with huggingface optimization
class WarmUpLR(_LRScheduler):
"""warmup_training learning rate scheduler
Args:
optimizer: optimzier(e.g. SGD)
total_iters: totoal_iters of warmup phase
"""
def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1):
self.total_iters = iter_per_epoch * warmup_epoch
self.iter_per_epoch = iter_per_epoch
super().__init__(optimizer, last_epoch)
def get_lr(self):
"""we will use the first m batches, and set the learning
rate to base_lr * m / total_iters
"""
return [
base_lr * self.last_epoch / (self.total_iters + 1e-8)
for base_lr in self.base_lrs
]
class WarmupLrScheduler(_LRScheduler):
def __init__(
self,
optimizer,
warmup_epoch=500,
warmup_iter=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.warmup_epoch = warmup_epoch
self.warmup_iter = warmup_iter
self.warmup_ratio = warmup_ratio
self.warmup = warmup
super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
@ -24,7 +48,7 @@ class WarmupLrScheduler(_LRScheduler):
return lrs
def get_lr_ratio(self):
if self.last_epoch < self.warmup_epoch:
if self.last_epoch < self.warmup_iter:
ratio = self.get_warmup_ratio()
else:
ratio = self.get_main_ratio()
@ -35,7 +59,7 @@ class WarmupLrScheduler(_LRScheduler):
def get_warmup_ratio(self):
assert self.warmup in ("linear", "exp")
alpha = self.last_epoch / self.warmup_epoch
alpha = self.last_epoch / self.warmup_iter
if self.warmup == "linear":
ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
elif self.warmup == "exp":
@ -48,22 +72,22 @@ class WarmupPolyLrScheduler(WarmupLrScheduler):
self,
optimizer,
power,
max_epoch,
warmup_epoch=500,
max_iter,
warmup_iter=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.power = power
self.max_epoch = max_epoch
self.max_iter = max_iter
super(WarmupPolyLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_epoch = self.last_epoch - self.warmup_epoch
real_max_epoch = self.max_epoch - self.warmup_epoch
alpha = real_epoch / real_max_epoch
real_iter = self.last_epoch - self.warmup_iter
real_max_iter = self.max_iter - self.warmup_iter
alpha = real_iter / real_max_iter
ratio = (1 - alpha) ** self.power
return ratio
@ -74,7 +98,7 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
optimizer,
gamma,
interval=1,
warmup_epoch=500,
warmup_iter=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
@ -82,12 +106,12 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
self.gamma = gamma
self.interval = interval
super(WarmupExpLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_epoch = self.last_epoch - self.warmup_epoch
ratio = self.gamma ** (real_epoch // self.interval)
real_iter = self.last_epoch - self.warmup_iter
ratio = self.gamma ** (real_iter // self.interval)
return ratio
@ -95,25 +119,26 @@ class WarmupCosineLrScheduler(WarmupLrScheduler):
def __init__(
self,
optimizer,
max_epoch,
max_iter,
eta_ratio=0,
warmup_epoch=500,
warmup_iter=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.eta_ratio = eta_ratio
self.max_epoch = max_epoch
self.max_iter = max_iter
super(WarmupCosineLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_max_epoch = self.max_epoch - self.warmup_epoch
real_iter = self.last_epoch - self.warmup_iter
real_max_iter = self.max_iter - self.warmup_iter
return (
self.eta_ratio
+ (1 - self.eta_ratio)
* (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
* (1 + math.cos(math.pi * self.last_epoch / real_max_iter))
/ 2
)
@ -124,7 +149,7 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
optimizer,
milestones: list,
gamma=0.1,
warmup_epoch=500,
warmup_iter=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
@ -132,10 +157,10 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
self.milestones = milestones
self.gamma = gamma
super(WarmupStepLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_epoch = self.last_epoch - self.warmup_epoch
ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
real_iter = self.last_epoch - self.warmup_iter
ratio = self.gamma ** bisect_right(self.milestones, real_iter)
return ratio

View File

@ -187,6 +187,7 @@ def get_params() -> AttributeDict:
"valid_interval": 200,
"env_info": get_env_info(),
"sampling_rate": 24000,
"audio_normalization": False,
"chunk_size": 1.0, # in seconds
"lambda_adv": 3.0, # loss scaling coefficient for adversarial loss
"lambda_wav": 1.0, # loss scaling coefficient for waveform loss
@ -276,13 +277,13 @@ def get_model(params: AttributeDict) -> nn.Module:
}
discriminator_params = {
"stft_discriminator_n_filters": 32,
"discriminator_epoch_start": 3,
"discriminator_epoch_start": 5,
"n_ffts": [1024, 2048, 512],
"hop_lengths": [256, 512, 128],
"win_lengths": [1024, 2048, 512],
}
inference_params = {
"target_bw": 12,
"target_bw": 6,
}
params.update(generator_params)
@ -353,6 +354,11 @@ def prepare_input(
:, params.sampling_rate : params.sampling_rate + params.sampling_rate
]
if params.audio_normalization:
mean = audio.mean(dim=-1, keepdim=True)
std = audio.std(dim=-1, keepdim=True)
audio = (audio - mean) / (std + 1e-7)
return audio, audio_lens, features, features_lens
@ -532,6 +538,10 @@ def train_one_epoch(
save_bad_model()
raise
# step per iteration
scheduler_g.step()
scheduler_d.step()
if params.print_diagnostics and batch_idx == 5:
return
@ -1009,16 +1019,16 @@ def run(rank, world_size, args):
scheduler_g = WarmupCosineLrScheduler(
optimizer=optimizer_g,
max_epoch=params.num_epochs,
max_iter=params.num_epochs * 1500,
eta_ratio=0.1,
warmup_epoch=params.discriminator_epoch_start,
warmup_iter=params.discriminator_epoch_start * 1500,
warmup_ratio=1e-4,
)
scheduler_d = WarmupCosineLrScheduler(
optimizer=optimizer_d,
max_epoch=params.num_epochs,
max_iter=params.num_epochs * 1500,
eta_ratio=0.1,
warmup_epoch=params.discriminator_epoch_start,
warmup_iter=params.discriminator_epoch_start * 1500,
warmup_ratio=1e-4,
)
@ -1128,10 +1138,6 @@ def run(rank, world_size, args):
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
# step per epoch
scheduler_g.step()
scheduler_d.step()
logging.info("Done!")
if world_size > 1: