Mirror of https://github.com/k2-fsa/icefall.git
fixed loss functions & scaling factors
This commit is contained in:
parent 58f6562824
commit 01cc307664
@@ -142,10 +142,10 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
         return F.mse_loss(x, x.new_zeros(x.size()))
 
     def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) - x).mean()
+        return F.relu(torch.ones_like(x) - x).mean()
 
     def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) + x).mean()
+        return F.relu(torch.ones_like(x) + x).mean()
 
 
 class FeatureLoss(torch.nn.Module):
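
For reference, a minimal standalone sketch of the hinge terms this hunk settles on (standard PyTorch; the tensor values are illustrative). `torch.ones_like(x)` produces the same ones tensor as `x.new_ones(x.size())` but states the intent, matching shape, dtype, and device, more directly:

    import torch
    import torch.nn.functional as F

    def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
        # Penalize real logits that fall below the +1 margin: mean(relu(1 - x)).
        return F.relu(torch.ones_like(x) - x).mean()

    def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
        # Penalize fake logits that rise above the -1 margin: mean(relu(1 + x)).
        return F.relu(torch.ones_like(x) + x).mean()

    logits = torch.tensor([0.5, 1.5, -2.0])
    print(hinge_real_loss(logits))  # (0.5 + 0.0 + 3.0) / 3 = 1.1667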
@@ -200,7 +200,7 @@ class FeatureLoss(torch.nn.Module):
                 feats_ = feats_[:-1]
             for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                 feat_match_loss_ += (
-                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
+                    F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
                 ).mean()
             if self.average_by_layers:
                 feat_match_loss_ /= j + 1
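
A self-contained sketch of the normalized feature-matching term the new line computes, assuming `feats_hat` and `feats` are lists of per-layer discriminator feature maps for generated and real audio. `detach()` on the real features stops the generator's gradient from flowing through the real branch, and dividing by the mean absolute real activation puts all layers on a comparable scale:

    import torch
    import torch.nn.functional as F

    def feature_match_loss(feats_hat: list, feats: list) -> torch.Tensor:
        loss = 0.0
        for j, (feat_hat_, feat_) in enumerate(zip(feats_hat, feats)):
            # L1 distance, normalized by the scale of the real feature map.
            loss += F.l1_loss(feat_hat_, feat_.detach()) / feat_.detach().abs().mean()
        return loss / (j + 1)  # average over layers

    feats = [torch.randn(2, 16, 100) for _ in range(3)]
    feats_hat = [f + 0.1 * torch.randn_like(f) for f in feats]
    print(feature_match_loss(feats_hat, feats))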
@@ -272,9 +272,16 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             mel_hat = wav_to_spec(x_hat.squeeze(1))
             mel = wav_to_spec(x.squeeze(1))
 
-            mel_loss += F.l1_loss(
-                mel_hat, mel, reduce=True, reduction="mean"
-            ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
+            mel_loss += (
+                F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+                + (
+                    (
+                        (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )
 
         # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
         # mel = self.wav_to_spec(x.squeeze(1))
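
In words: the old L1 + MSE pair on linear mel magnitudes becomes L1 plus a log-magnitude RMSE, where the root-mean-square is taken over the mel-bin axis (`dim=-2`) and then averaged over batch and frames. A minimal sketch, assuming `mel_hat` and `mel` are (batch, n_mels, frames) magnitude spectrograms:

    import torch
    import torch.nn.functional as F

    def mel_reconstruction_loss(mel_hat: torch.Tensor, mel: torch.Tensor) -> torch.Tensor:
        l1 = F.l1_loss(mel_hat, mel)  # L1 on linear magnitudes
        # Squared log-magnitude error, averaged over mel bins, rooted per frame.
        log_diff = torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)
        log_rmse = (log_diff ** 2).mean(dim=-2).sqrt().mean()
        return l1 + log_rmse

    mel = torch.rand(2, 80, 100)
    mel_hat = mel + 0.01 * torch.rand_like(mel)
    print(mel_reconstruction_loss(mel_hat, mel))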
@@ -307,7 +314,7 @@ class WavReconstructionLoss(torch.nn.Module):
             Tensor: Wav loss value.
 
         """
-        wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
+        wav_loss = F.l1_loss(x, x_hat)
 
         return wav_loss
 
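Dropping `reduce=True, reduction="mean"` is safe cleanup: `reduce` has been deprecated in PyTorch since 0.4.1, and `F.l1_loss` defaults to `reduction="mean"` anyway, so the two calls compute the same value (note that the mel-loss hunk above still passes the deprecated argument). A quick check:

    import torch
    import torch.nn.functional as F

    x, x_hat = torch.randn(2, 24000), torch.randn(2, 24000)
    assert torch.allclose(F.l1_loss(x, x_hat), F.l1_loss(x, x_hat, reduction="mean"))
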
@@ -4,16 +4,40 @@ from bisect import bisect_right
 from torch.optim.lr_scheduler import _LRScheduler
 
 
+# It will be replaced with huggingface optimization
+class WarmUpLR(_LRScheduler):
+    """warmup_training learning rate scheduler
+    Args:
+        optimizer: optimizer (e.g. SGD)
+        total_iters: total_iters of warmup phase
+    """
+
+    def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1):
+
+        self.total_iters = iter_per_epoch * warmup_epoch
+        self.iter_per_epoch = iter_per_epoch
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        """we will use the first m batches, and set the learning
+        rate to base_lr * m / total_iters
+        """
+        return [
+            base_lr * self.last_epoch / (self.total_iters + 1e-8)
+            for base_lr in self.base_lrs
+        ]
+
+
 class WarmupLrScheduler(_LRScheduler):
     def __init__(
         self,
         optimizer,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
-        self.warmup_epoch = warmup_epoch
+        self.warmup_iter = warmup_iter
         self.warmup_ratio = warmup_ratio
         self.warmup = warmup
         super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
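
A hedged usage sketch of the new `WarmUpLR` (assuming the class above is in scope; the SGD setup and iteration counts are illustrative). `get_lr` scales each base LR by `last_epoch / total_iters`, and `_LRScheduler` increments `last_epoch` on every `step()`, so the scheduler is meant to be stepped once per batch:

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # e.g. 100 batches per epoch, 2 warmup epochs -> 200 warmup iterations
    scheduler = WarmUpLR(optimizer, iter_per_epoch=100, warmup_epoch=2)

    for _ in range(200):
        optimizer.step()
        scheduler.step()  # LR ramps linearly from ~0 toward the base 0.1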
@@ -24,7 +48,7 @@ class WarmupLrScheduler(_LRScheduler):
         return lrs
 
     def get_lr_ratio(self):
-        if self.last_epoch < self.warmup_epoch:
+        if self.last_epoch < self.warmup_iter:
             ratio = self.get_warmup_ratio()
         else:
             ratio = self.get_main_ratio()
@@ -35,7 +59,7 @@ class WarmupLrScheduler(_LRScheduler):
 
     def get_warmup_ratio(self):
         assert self.warmup in ("linear", "exp")
-        alpha = self.last_epoch / self.warmup_epoch
+        alpha = self.last_epoch / self.warmup_iter
         if self.warmup == "linear":
             ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
         elif self.warmup == "exp":
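
A worked example of the linear branch with the defaults from these hunks (`warmup_ratio=5e-4`, `warmup_iter=500`): halfway through warmup, `alpha = 250 / 500 = 0.5`, so the LR sits at roughly half its base value:

    warmup_ratio, warmup_iter = 5e-4, 500
    for last_epoch in (0, 250, 499):
        alpha = last_epoch / warmup_iter
        ratio = warmup_ratio + (1 - warmup_ratio) * alpha
        print(last_epoch, round(ratio, 5))  # 0.0005, 0.50025, 0.998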
@@ -48,22 +72,22 @@ class WarmupPolyLrScheduler(WarmupLrScheduler):
         self,
         optimizer,
         power,
-        max_epoch,
-        warmup_epoch=500,
+        max_iter,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.power = power
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupPolyLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        real_max_epoch = self.max_epoch - self.warmup_epoch
-        alpha = real_epoch / real_max_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
+        alpha = real_iter / real_max_iter
         ratio = (1 - alpha) ** self.power
         return ratio
 
@@ -74,7 +98,7 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
         optimizer,
         gamma,
         interval=1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -82,12 +106,12 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
         self.gamma = gamma
         self.interval = interval
         super(WarmupExpLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** (real_epoch // self.interval)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** (real_iter // self.interval)
         return ratio
 
 
@@ -95,25 +119,26 @@ class WarmupCosineLrScheduler(WarmupLrScheduler):
     def __init__(
         self,
         optimizer,
-        max_epoch,
+        max_iter,
         eta_ratio=0,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.eta_ratio = eta_ratio
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupCosineLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_max_epoch = self.max_epoch - self.warmup_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
         return (
             self.eta_ratio
             + (1 - self.eta_ratio)
-            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_iter))
             / 2
         )
 
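A sketch of the post-warmup cosine ratio exactly as now written. One apparent leftover: the hunk introduces `real_iter`, but the cosine term still divides `self.last_epoch` (not `real_iter`) by `real_max_iter`, so `real_iter` is computed and never used. The numbers below are illustrative (1500 iterations per epoch, matching the `run()` hunk further down):

    import math

    def cosine_main_ratio(last_epoch, max_iter, warmup_iter, eta_ratio=0.1):
        # Mirrors the new get_main_ratio: numerator is last_epoch, not real_iter.
        real_max_iter = max_iter - warmup_iter
        return (
            eta_ratio
            + (1 - eta_ratio) * (1 + math.cos(math.pi * last_epoch / real_max_iter)) / 2
        )

    print(cosine_main_ratio(7500, 45000, 7500))   # ~0.91, just after warmup
    print(cosine_main_ratio(37500, 45000, 7500))  # 0.1, decayed to eta_ratio
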
@@ -124,7 +149,7 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
         optimizer,
         milestones: list,
         gamma=0.1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -132,10 +157,10 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
         self.milestones = milestones
         self.gamma = gamma
         super(WarmupStepLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** bisect_right(self.milestones, real_iter)
         return ratio
 
@@ -187,6 +187,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 200,
             "env_info": get_env_info(),
             "sampling_rate": 24000,
+            "audio_normalization": False,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
@@ -276,13 +277,13 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
-        "discriminator_epoch_start": 3,
+        "discriminator_epoch_start": 5,
         "n_ffts": [1024, 2048, 512],
         "hop_lengths": [256, 512, 128],
         "win_lengths": [1024, 2048, 512],
     }
     inference_params = {
-        "target_bw": 12,
+        "target_bw": 6,
     }
 
     params.update(generator_params)
@@ -353,6 +354,11 @@ def prepare_input(
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]
 
+    if params.audio_normalization:
+        mean = audio.mean(dim=-1, keepdim=True)
+        std = audio.std(dim=-1, keepdim=True)
+        audio = (audio - mean) / (std + 1e-7)
+
     return audio, audio_lens, features, features_lens
 
 
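When `audio_normalization` is enabled, the new block standardizes each utterance to zero mean and unit variance along the time axis; the `1e-7` keeps the division finite on silent (zero-variance) chunks. A minimal demonstration on a (batch, time) tensor:

    import torch

    audio = torch.randn(4, 24000) * 3.0 + 1.0
    mean = audio.mean(dim=-1, keepdim=True)
    std = audio.std(dim=-1, keepdim=True)
    normalized = (audio - mean) / (std + 1e-7)
    print(normalized.mean(dim=-1))  # ~0 per utterance
    print(normalized.std(dim=-1))   # ~1 per utterance
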
@@ -532,6 +538,10 @@ def train_one_epoch(
                 save_bad_model()
                 raise
 
+        # step per iteration
+        scheduler_g.step()
+        scheduler_d.step()
+
         if params.print_diagnostics and batch_idx == 5:
             return
 
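Together with the removal of the per-epoch calls in the hunk at old line 1128 below, this moves both schedulers to one `step()` per optimizer update. The resulting loop shape, sketched with stand-in models and a built-in scheduler (the real code uses `WarmupCosineLrScheduler`):

    import torch

    model_g, model_d = torch.nn.Linear(8, 8), torch.nn.Linear(8, 1)
    optimizer_g = torch.optim.AdamW(model_g.parameters(), lr=3e-4)
    optimizer_d = torch.optim.AdamW(model_d.parameters(), lr=3e-4)
    scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_g, step_size=1000)
    scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_d, step_size=1000)

    for batch_idx in range(10):  # stand-in for the training dataloader
        x = torch.randn(4, 8)
        loss_d = model_d(model_g(x).detach()).mean()
        optimizer_d.zero_grad(); loss_d.backward(); optimizer_d.step()
        loss_g = -model_d(model_g(x)).mean()
        optimizer_g.zero_grad(); loss_g.backward(); optimizer_g.step()
        scheduler_g.step()  # step per iteration (was: once per epoch)
        scheduler_d.step()
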
@@ -1009,16 +1019,16 @@ def run(rank, world_size, args):
 
     scheduler_g = WarmupCosineLrScheduler(
         optimizer=optimizer_g,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )
     scheduler_d = WarmupCosineLrScheduler(
         optimizer=optimizer_d,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )
 
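The `* 1500` factor converts the previous epoch counts into iteration counts, i.e. it assumes roughly 1500 batches per epoch. A quick sanity check under illustrative settings (`num_epochs = 30` is an assumption; `discriminator_epoch_start = 5` matches the `get_model` hunk above):

    num_epochs = 30
    discriminator_epoch_start = 5
    max_iter = num_epochs * 1500                     # 45000 scheduler steps total
    warmup_iter = discriminator_epoch_start * 1500   # 7500 warmup steps
    print(max_iter, warmup_iter)
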
@@ -1128,10 +1138,6 @@ def run(rank, world_size, args):
             best_valid_filename = params.exp_dir / "best-valid-loss.pt"
             copyfile(src=filename, dst=best_valid_filename)
 
-        # step per epoch
-        scheduler_g.step()
-        scheduler_d.step()
-
     logging.info("Done!")
 
     if world_size > 1: