Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 01cc307664 ("fixed loss functions & scaling factors"), parent 58f6562824.
@@ -142,10 +142,10 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
         return F.mse_loss(x, x.new_zeros(x.size()))
 
     def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) - x).mean()
+        return F.relu(torch.ones_like(x) - x).mean()
 
     def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) + x).mean()
+        return F.relu(torch.ones_like(x) + x).mean()
 
 
 class FeatureLoss(torch.nn.Module):
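For reference, a minimal, self-contained sketch of the hinge terms touched above: `torch.ones_like(x)` builds the same all-ones tensor as `x.new_ones(x.size())` (same shape, dtype and device), and the two losses push discriminator scores for real inputs above +1 and for generated inputs below -1. The helper names below are illustrative, not taken from the repository.

```python
# Hinge discriminator losses, matching the formulas in the hunk above.
import torch
import torch.nn.functional as F


def hinge_real_loss(scores_real: torch.Tensor) -> torch.Tensor:
    # mean(relu(1 - D(x))): penalize real scores that fall below +1
    return F.relu(torch.ones_like(scores_real) - scores_real).mean()


def hinge_fake_loss(scores_fake: torch.Tensor) -> torch.Tensor:
    # mean(relu(1 + D(x_hat))): penalize fake scores that rise above -1
    return F.relu(torch.ones_like(scores_fake) + scores_fake).mean()


if __name__ == "__main__":
    real = torch.tensor([1.5, 0.2, -0.3])   # hypothetical discriminator outputs
    fake = torch.tensor([-1.2, 0.4, 0.0])
    print(hinge_real_loss(real).item(), hinge_fake_loss(fake).item())
```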
@@ -200,7 +200,7 @@ class FeatureLoss(torch.nn.Module):
                 feats_ = feats_[:-1]
             for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                 feat_match_loss_ += (
-                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
+                    F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
                 ).mean()
             if self.average_by_layers:
                 feat_match_loss_ /= j + 1
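The rewritten term detaches the real-branch features, so gradients flow only through the generated path, and normalizes each layer's L1 distance by the mean magnitude of the reference features. A sketch under the assumption that `feats_hat` and `feats` are lists of per-layer discriminator feature maps (names and shapes are made up):

```python
# Magnitude-normalized feature-matching loss (illustrative, not the icefall module).
import torch
import torch.nn.functional as F


def feature_match_loss(feats_hat, feats, average_by_layers=True):
    loss = 0.0
    for j, (feat_hat_, feat_) in enumerate(zip(feats_hat, feats)):
        ref = feat_.detach()                    # no gradient through the real branch
        loss += F.l1_loss(feat_hat_, ref) / ref.abs().mean()
    if average_by_layers:
        loss /= j + 1                           # assumes non-empty feature lists
    return loss


feats = [torch.randn(2, 16, 50) for _ in range(3)]
feats_hat = [f + 0.1 * torch.randn_like(f) for f in feats]
print(feature_match_loss(feats_hat, feats))
```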
@@ -272,9 +272,16 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             mel_hat = wav_to_spec(x_hat.squeeze(1))
             mel = wav_to_spec(x.squeeze(1))
 
-            mel_loss += F.l1_loss(
-                mel_hat, mel, reduce=True, reduction="mean"
-            ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
+            mel_loss += (
+                F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+                + (
+                    (
+                        (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )
 
         # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
         # mel = self.wav_to_spec(x.squeeze(1))
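The reconstruction term now pairs an L1 distance on the mel spectrograms with a root-mean-square error on their log magnitudes (squared differences averaged over the mel axis, square-rooted per frame, then averaged), replacing the earlier L1 + MSE combination. A standalone sketch, assuming `mel` and `mel_hat` have shape (batch, n_mels, frames):

```python
# L1 + log-magnitude RMSE mel reconstruction term (shapes are assumptions).
import torch
import torch.nn.functional as F


def mel_reconstruction_loss(mel_hat: torch.Tensor, mel: torch.Tensor) -> torch.Tensor:
    l1 = F.l1_loss(mel_hat, mel)
    log_rmse = (
        ((torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)) ** 2)
        .mean(dim=-2)   # average over the mel-bin axis
        ** 0.5          # RMSE per frame
    ).mean()            # average over batch and time
    return l1 + log_rmse


mel = torch.rand(2, 80, 100)
mel_hat = mel + 0.05 * torch.rand_like(mel)
print(mel_reconstruction_loss(mel_hat, mel))
```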
@@ -307,7 +314,7 @@ class WavReconstructionLoss(torch.nn.Module):
             Tensor: Wav loss value.
 
         """
-        wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
+        wav_loss = F.l1_loss(x, x_hat)
 
         return wav_loss
 
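The simplification is behavior-preserving: `F.l1_loss` defaults to `reduction="mean"`, and the removed `reduce=` keyword has long been deprecated in PyTorch. A quick check with made-up shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 24000)      # hypothetical reference waveforms
x_hat = torch.randn(2, 24000)  # hypothetical reconstructions

# reduction defaults to "mean", so this matches the previous explicit arguments
wav_loss = F.l1_loss(x, x_hat)
print(wav_loss)
```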
@@ -4,16 +4,40 @@ from bisect import bisect_right
 from torch.optim.lr_scheduler import _LRScheduler
 
 
+# It will be replaced with huggingface optimization
+class WarmUpLR(_LRScheduler):
+    """warmup_training learning rate scheduler
+    Args:
+        optimizer: optimizer (e.g. SGD)
+        total_iters: total_iters of warmup phase
+    """
+
+    def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1):
+
+        self.total_iters = iter_per_epoch * warmup_epoch
+        self.iter_per_epoch = iter_per_epoch
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        """we will use the first m batches, and set the learning
+        rate to base_lr * m / total_iters
+        """
+        return [
+            base_lr * self.last_epoch / (self.total_iters + 1e-8)
+            for base_lr in self.base_lrs
+        ]
+
+
 class WarmupLrScheduler(_LRScheduler):
     def __init__(
         self,
         optimizer,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
-        self.warmup_epoch = warmup_epoch
+        self.warmup_iter = warmup_iter
         self.warmup_ratio = warmup_ratio
         self.warmup = warmup
         super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
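The new `WarmUpLR` ramps the learning rate linearly from roughly zero to each group's base LR over `iter_per_epoch * warmup_epoch` optimizer steps (`last_epoch` counts steps here, not epochs). A usage sketch: the class body is repeated so the snippet runs on its own, and the optimizer and iteration counts are made up.

```python
import torch
from torch.optim.lr_scheduler import _LRScheduler


class WarmUpLR(_LRScheduler):
    """Linear warmup: lr = base_lr * step / total_iters (copied from the hunk above)."""

    def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1):
        self.total_iters = iter_per_epoch * warmup_epoch
        self.iter_per_epoch = iter_per_epoch
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [
            base_lr * self.last_epoch / (self.total_iters + 1e-8)
            for base_lr in self.base_lrs
        ]


model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = WarmUpLR(optimizer, iter_per_epoch=1500, warmup_epoch=2)

for step in range(3):
    optimizer.step()
    scheduler.step()                 # one scheduler step per optimizer step
    print(step, scheduler.get_last_lr())
```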
@@ -24,7 +48,7 @@ class WarmupLrScheduler(_LRScheduler):
         return lrs
 
     def get_lr_ratio(self):
-        if self.last_epoch < self.warmup_epoch:
+        if self.last_epoch < self.warmup_iter:
             ratio = self.get_warmup_ratio()
         else:
             ratio = self.get_main_ratio()
@@ -35,7 +59,7 @@ class WarmupLrScheduler(_LRScheduler):
 
     def get_warmup_ratio(self):
         assert self.warmup in ("linear", "exp")
-        alpha = self.last_epoch / self.warmup_epoch
+        alpha = self.last_epoch / self.warmup_iter
         if self.warmup == "linear":
             ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
         elif self.warmup == "exp":
@@ -48,22 +72,22 @@ class WarmupPolyLrScheduler(WarmupLrScheduler):
         self,
         optimizer,
         power,
-        max_epoch,
-        warmup_epoch=500,
+        max_iter,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.power = power
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupPolyLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        real_max_epoch = self.max_epoch - self.warmup_epoch
-        alpha = real_epoch / real_max_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
+        alpha = real_iter / real_max_iter
         ratio = (1 - alpha) ** self.power
         return ratio
 
@@ -74,7 +98,7 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
         optimizer,
         gamma,
         interval=1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -82,12 +106,12 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
         self.gamma = gamma
         self.interval = interval
         super(WarmupExpLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** (real_epoch // self.interval)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** (real_iter // self.interval)
         return ratio
 
 
@@ -95,25 +119,26 @@ class WarmupCosineLrScheduler(WarmupLrScheduler):
     def __init__(
         self,
         optimizer,
-        max_epoch,
+        max_iter,
         eta_ratio=0,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.eta_ratio = eta_ratio
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupCosineLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_max_epoch = self.max_epoch - self.warmup_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
         return (
             self.eta_ratio
             + (1 - self.eta_ratio)
-            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_iter))
             / 2
         )
 
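After the rename, the cosine scheduler is driven by optimizer steps: warmup for `warmup_iter` steps, then a cosine decay toward `eta_ratio` of the base LR over `max_iter` steps. A self-contained sketch of the resulting ratio curve, using the linear warmup formula shown earlier and made-up constants:

```python
import math


def lr_ratio(step, warmup_iter=500, max_iter=10_000, warmup_ratio=5e-4, eta_ratio=0.1):
    if step < warmup_iter:
        # linear warmup as in get_warmup_ratio(): warmup_ratio -> 1
        alpha = step / warmup_iter
        return warmup_ratio + (1 - warmup_ratio) * alpha
    # cosine decay toward eta_ratio, as in get_main_ratio() above
    real_max_iter = max_iter - warmup_iter
    return eta_ratio + (1 - eta_ratio) * (1 + math.cos(math.pi * step / real_max_iter)) / 2


for step in (0, 250, 500, 5_000, 9_500):
    print(step, round(lr_ratio(step), 4))
```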
@@ -124,7 +149,7 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
         optimizer,
         milestones: list,
         gamma=0.1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -132,10 +157,10 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
         self.milestones = milestones
         self.gamma = gamma
         super(WarmupStepLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )
 
     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** bisect_right(self.milestones, real_iter)
         return ratio
@@ -187,6 +187,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 200,
             "env_info": get_env_info(),
             "sampling_rate": 24000,
+            "audio_normalization": False,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
@@ -276,13 +277,13 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
-        "discriminator_epoch_start": 3,
+        "discriminator_epoch_start": 5,
         "n_ffts": [1024, 2048, 512],
         "hop_lengths": [256, 512, 128],
         "win_lengths": [1024, 2048, 512],
     }
     inference_params = {
-        "target_bw": 12,
+        "target_bw": 6,
     }
 
     params.update(generator_params)
@@ -353,6 +354,11 @@ def prepare_input(
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]
 
+    if params.audio_normalization:
+        mean = audio.mean(dim=-1, keepdim=True)
+        std = audio.std(dim=-1, keepdim=True)
+        audio = (audio - mean) / (std + 1e-7)
+
     return audio, audio_lens, features, features_lens
 
 
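When the new `audio_normalization` flag is enabled, each utterance is standardized to roughly zero mean and unit variance along the time axis; the `1e-7` guards against division by zero on near-silent chunks. A minimal sketch with a made-up batch:

```python
import torch

audio = torch.randn(4, 24000) * 0.3 + 0.05   # (batch, samples), hypothetical

mean = audio.mean(dim=-1, keepdim=True)
std = audio.std(dim=-1, keepdim=True)
audio = (audio - mean) / (std + 1e-7)

print(audio.mean(dim=-1), audio.std(dim=-1))  # ~0 and ~1 per utterance
```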
@@ -532,6 +538,10 @@ def train_one_epoch(
                 save_bad_model()
                 raise
 
+        # step per iteration
+        scheduler_g.step()
+        scheduler_d.step()
+
         if params.print_diagnostics and batch_idx == 5:
             return
 
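The schedulers are now advanced once per batch inside `train_one_epoch` instead of once per epoch in `run()` (the per-epoch calls are dropped in the last hunk below), matching the `*_iter` units of the rewritten schedulers. A toy, self-contained illustration of the pattern, with a placeholder model, data and schedule:

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda it: min(1.0, it / 100))

for epoch in range(2):
    for batch_idx in range(150):                     # iterations within an epoch
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()                             # step per iteration
    print(epoch, scheduler.get_last_lr())
```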
@@ -1009,16 +1019,16 @@ def run(rank, world_size, args):
 
     scheduler_g = WarmupCosineLrScheduler(
         optimizer=optimizer_g,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )
     scheduler_d = WarmupCosineLrScheduler(
         optimizer=optimizer_d,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )
 
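`max_iter` and `warmup_iter` are now expressed in optimizer steps; the hard-coded factor of 1500 appears to stand in for the approximate number of training batches per epoch. A small sketch of the conversion, with placeholder numbers:

```python
num_epochs = 20                     # placeholder values
discriminator_epoch_start = 5
approx_steps_per_epoch = 1500       # the hard-coded factor used in the hunk above

max_iter = num_epochs * approx_steps_per_epoch                     # 30000 total steps
warmup_iter = discriminator_epoch_start * approx_steps_per_epoch   # 7500 warmup steps
print(max_iter, warmup_iter)
```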
@@ -1128,10 +1138,6 @@ def run(rank, world_size, args):
             best_valid_filename = params.exp_dir / "best-valid-loss.pt"
             copyfile(src=filename, dst=best_valid_filename)
 
-        # step per epoch
-        scheduler_g.step()
-        scheduler_d.step()
-
     logging.info("Done!")
 
     if world_size > 1: