fixed loss functions & scaling factors

JinZr 2024-10-07 01:03:26 +08:00
parent 58f6562824
commit 01cc307664
3 changed files with 80 additions and 42 deletions

View File

@@ -142,10 +142,10 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
         return F.mse_loss(x, x.new_zeros(x.size()))

     def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) - x).mean()
+        return F.relu(torch.ones_like(x) - x).mean()

     def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) + x).mean()
+        return F.relu(torch.ones_like(x) + x).mean()


 class FeatureLoss(torch.nn.Module):
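Note: torch.ones_like(x) is numerically identical to x.new_ones(x.size()), so this hunk is a readability change only. A minimal standalone sketch of the hinge discriminator objective these helpers implement (tensor names and shapes are illustrative, not from the repo):

    import torch
    import torch.nn.functional as F

    real_logits = torch.randn(4, 1, 100)   # discriminator scores on real audio
    fake_logits = torch.randn(4, 1, 100)   # discriminator scores on generated audio

    # hinge objective: push real scores above +1 and fake scores below -1
    d_loss_real = F.relu(torch.ones_like(real_logits) - real_logits).mean()
    d_loss_fake = F.relu(torch.ones_like(fake_logits) + fake_logits).mean()
    d_loss = d_loss_real + d_loss_fake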
@@ -200,7 +200,7 @@ class FeatureLoss(torch.nn.Module):
                 feats_ = feats_[:-1]
             for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                 feat_match_loss_ += (
-                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
+                    F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
                 ).mean()
             if self.average_by_layers:
                 feat_match_loss_ /= j + 1
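The feature-matching term now uses an L1 distance between generator and reference discriminator features, normalised by the mean magnitude of the detached reference features so the discriminator receives no gradient through the normaliser. A rough standalone sketch, assuming feats_hat_ and feats_ are lists of per-layer feature maps:

    import torch
    import torch.nn.functional as F

    feats_hat_ = [torch.randn(2, 16, 50) for _ in range(3)]  # features of generated audio
    feats_ = [torch.randn(2, 16, 50) for _ in range(3)]      # features of real audio

    feat_match_loss_ = 0.0
    for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
        # scale-invariant L1: divide by the mean magnitude of the (detached) reference
        feat_match_loss_ += (
            F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
        ).mean()
    feat_match_loss_ /= j + 1  # average over layers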
@@ -272,9 +272,16 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             mel_hat = wav_to_spec(x_hat.squeeze(1))
             mel = wav_to_spec(x.squeeze(1))
-            mel_loss += F.l1_loss(
-                mel_hat, mel, reduce=True, reduction="mean"
-            ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
+            mel_loss += (
+                F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+                + (
+                    (
+                        (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )

         # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
         # mel = self.wav_to_spec(x.squeeze(1))
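The per-resolution reconstruction term now pairs an L1 distance on the mel spectrogram with a root-mean-square error on log magnitudes, taken over the frequency axis, replacing the earlier L1 + MSE combination. A hedged sketch with dummy tensors (shapes are assumptions):

    import torch
    import torch.nn.functional as F

    mel = torch.rand(2, 80, 100) + 1e-3      # (batch, n_mels, frames), reference
    mel_hat = torch.rand(2, 80, 100) + 1e-3  # reconstruction

    l1_term = F.l1_loss(mel_hat, mel)
    log_rmse_term = (
        ((torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)) ** 2)
        .mean(dim=-2)  # average over the mel-bin axis
        ** 0.5         # RMSE per frame
    ).mean()           # average over batch and frames
    mel_loss = l1_term + log_rmse_term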
@@ -307,7 +314,7 @@ class WavReconstructionLoss(torch.nn.Module):
             Tensor: Wav loss value.
         """
-        wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
+        wav_loss = F.l1_loss(x, x_hat)

         return wav_loss

View File

@@ -4,16 +4,40 @@ from bisect import bisect_right
 from torch.optim.lr_scheduler import _LRScheduler


+# It will be replaced with huggingface optimization
+class WarmUpLR(_LRScheduler):
+    """warmup_training learning rate scheduler
+    Args:
+        optimizer: optimizer (e.g. SGD)
+        total_iters: total iters of warmup phase
+    """
+
+    def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1):
+        self.total_iters = iter_per_epoch * warmup_epoch
+        self.iter_per_epoch = iter_per_epoch
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        """we will use the first m batches, and set the learning
+        rate to base_lr * m / total_iters
+        """
+        return [
+            base_lr * self.last_epoch / (self.total_iters + 1e-8)
+            for base_lr in self.base_lrs
+        ]
+
+
 class WarmupLrScheduler(_LRScheduler):
     def __init__(
         self,
         optimizer,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
-        self.warmup_epoch = warmup_epoch
+        self.warmup_iter = warmup_iter
         self.warmup_ratio = warmup_ratio
         self.warmup = warmup
         super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
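The newly added WarmUpLR ramps the learning rate linearly from 0 toward the base value over iter_per_epoch * warmup_epoch steps and is intended to be stepped once per batch. A minimal usage sketch, assuming the class defined above and a placeholder model/optimizer:

    import torch

    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = WarmUpLR(optimizer, iter_per_epoch=1500, warmup_epoch=2)

    for step in range(3000):
        optimizer.step()   # normally preceded by a forward/backward pass
        scheduler.step()   # per-iteration update: lr = base_lr * step / total_iters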
@@ -24,7 +48,7 @@ class WarmupLrScheduler(_LRScheduler):
         return lrs

     def get_lr_ratio(self):
-        if self.last_epoch < self.warmup_epoch:
+        if self.last_epoch < self.warmup_iter:
             ratio = self.get_warmup_ratio()
         else:
             ratio = self.get_main_ratio()
@@ -35,7 +59,7 @@ class WarmupLrScheduler(_LRScheduler):
     def get_warmup_ratio(self):
         assert self.warmup in ("linear", "exp")
-        alpha = self.last_epoch / self.warmup_epoch
+        alpha = self.last_epoch / self.warmup_iter
         if self.warmup == "linear":
             ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
         elif self.warmup == "exp":
@@ -48,22 +72,22 @@ class WarmupPolyLrScheduler(WarmupLrScheduler):
         self,
         optimizer,
         power,
-        max_epoch,
-        warmup_epoch=500,
+        max_iter,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.power = power
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupPolyLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        real_max_epoch = self.max_epoch - self.warmup_epoch
-        alpha = real_epoch / real_max_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
+        alpha = real_iter / real_max_iter
         ratio = (1 - alpha) ** self.power
         return ratio
@@ -74,7 +98,7 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
         optimizer,
         gamma,
         interval=1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -82,12 +106,12 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
         self.gamma = gamma
         self.interval = interval
         super(WarmupExpLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** (real_epoch // self.interval)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** (real_iter // self.interval)
         return ratio
@@ -95,25 +119,26 @@ class WarmupCosineLrScheduler(WarmupLrScheduler):
     def __init__(
         self,
         optimizer,
-        max_epoch,
+        max_iter,
         eta_ratio=0,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.eta_ratio = eta_ratio
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupCosineLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_max_epoch = self.max_epoch - self.warmup_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
         return (
             self.eta_ratio
             + (1 - self.eta_ratio)
-            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_iter))
             / 2
         )
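After warmup, the cosine scheduler decays the LR ratio from 1 toward eta_ratio as eta_ratio + (1 - eta_ratio) * (1 + cos(pi * t / (max_iter - warmup_iter))) / 2, where t is the scheduler's step count. A small standalone sketch of that curve, assuming the renamed iteration-based arguments and illustrative values:

    import math

    def cosine_ratio(step, max_iter, warmup_iter, eta_ratio=0.1):
        # mirrors WarmupCosineLrScheduler.get_main_ratio for steps past warmup
        real_max_iter = max_iter - warmup_iter
        return eta_ratio + (1 - eta_ratio) * (1 + math.cos(math.pi * step / real_max_iter)) / 2

    print(cosine_ratio(0, 30000, 7500))      # ~1.0 at the start of the decay
    print(cosine_ratio(22500, 30000, 7500))  # ~0.1, i.e. eta_ratio at the end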
@@ -124,7 +149,7 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
         optimizer,
         milestones: list,
         gamma=0.1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -132,10 +157,10 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
         self.milestones = milestones
         self.gamma = gamma
         super(WarmupStepLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** bisect_right(self.milestones, real_iter)
         return ratio

View File

@@ -187,6 +187,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 200,
             "env_info": get_env_info(),
             "sampling_rate": 24000,
+            "audio_normalization": False,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
@ -276,13 +277,13 @@ def get_model(params: AttributeDict) -> nn.Module:
} }
discriminator_params = { discriminator_params = {
"stft_discriminator_n_filters": 32, "stft_discriminator_n_filters": 32,
"discriminator_epoch_start": 3, "discriminator_epoch_start": 5,
"n_ffts": [1024, 2048, 512], "n_ffts": [1024, 2048, 512],
"hop_lengths": [256, 512, 128], "hop_lengths": [256, 512, 128],
"win_lengths": [1024, 2048, 512], "win_lengths": [1024, 2048, 512],
} }
inference_params = { inference_params = {
"target_bw": 12, "target_bw": 6,
} }
params.update(generator_params) params.update(generator_params)
@@ -353,6 +354,11 @@ def prepare_input(
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]

+    if params.audio_normalization:
+        mean = audio.mean(dim=-1, keepdim=True)
+        std = audio.std(dim=-1, keepdim=True)
+        audio = (audio - mean) / (std + 1e-7)
+
     return audio, audio_lens, features, features_lens
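With audio_normalization enabled, each waveform in the batch is standardised to roughly zero mean and unit variance before training; the 1e-7 term guards against division by zero on silent chunks. A standalone sketch with a dummy batch (shape is illustrative):

    import torch

    audio = torch.randn(8, 24000)  # (batch, samples), e.g. 1 s at 24 kHz

    mean = audio.mean(dim=-1, keepdim=True)
    std = audio.std(dim=-1, keepdim=True)
    normalized = (audio - mean) / (std + 1e-7)

    print(normalized.mean(dim=-1).abs().max())  # close to 0
    print(normalized.std(dim=-1))               # close to 1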
@@ -532,6 +538,10 @@ def train_one_epoch(
                 save_bad_model()
                 raise

+        # step per iteration
+        scheduler_g.step()
+        scheduler_d.step()
+
         if params.print_diagnostics and batch_idx == 5:
             return
@ -1009,16 +1019,16 @@ def run(rank, world_size, args):
scheduler_g = WarmupCosineLrScheduler( scheduler_g = WarmupCosineLrScheduler(
optimizer=optimizer_g, optimizer=optimizer_g,
max_epoch=params.num_epochs, max_iter=params.num_epochs * 1500,
eta_ratio=0.1, eta_ratio=0.1,
warmup_epoch=params.discriminator_epoch_start, warmup_iter=params.discriminator_epoch_start * 1500,
warmup_ratio=1e-4, warmup_ratio=1e-4,
) )
scheduler_d = WarmupCosineLrScheduler( scheduler_d = WarmupCosineLrScheduler(
optimizer=optimizer_d, optimizer=optimizer_d,
max_epoch=params.num_epochs, max_iter=params.num_epochs * 1500,
eta_ratio=0.1, eta_ratio=0.1,
warmup_epoch=params.discriminator_epoch_start, warmup_iter=params.discriminator_epoch_start * 1500,
warmup_ratio=1e-4, warmup_ratio=1e-4,
) )
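Both schedulers are now configured in iterations rather than epochs, and (per the hunk in train_one_epoch above) scheduler.step() is called once per batch instead of once per epoch. The factor of 1500 appears to be an assumed number of training batches per epoch, so warmup still ends when the discriminator starts and the cosine decay still spans the full run. A rough sketch of the resulting horizon, using illustrative values:

    num_epochs = 20                # illustrative
    discriminator_epoch_start = 5  # matches the updated default above
    batches_per_epoch = 1500       # assumption baked into the diff

    max_iter = num_epochs * batches_per_epoch                     # 30000 optimizer steps
    warmup_iter = discriminator_epoch_start * batches_per_epoch   # 7500 warmup steps
    print(max_iter, warmup_iter)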
@@ -1128,10 +1138,6 @@ def run(rank, world_size, args):
             best_valid_filename = params.exp_dir / "best-valid-loss.pt"
             copyfile(src=filename, dst=best_valid_filename)

-        # step per epoch
-        scheduler_g.step()
-        scheduler_d.step()
-
     logging.info("Done!")

     if world_size > 1: