Mirror of https://github.com/k2-fsa/icefall.git

Commit 58f6562824 (parent d83ce89fca): added scheduler w/ warmup

@@ -60,7 +60,7 @@ class Encodec(nn.Module):
         self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
             average_by_discriminators=True, loss_type="hinge"
         )
-        self.feature_match_loss = FeatureLoss(average_by_layers=False)
+        self.feature_match_loss = FeatureLoss()
         self.wav_reconstruction_loss = WavReconstructionLoss()
         self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
@@ -45,8 +45,8 @@ class GeneratorAdversarialLoss(torch.nn.Module):
             Tensor: Generator adversarial loss value.

        """
-        if isinstance(outputs, (tuple, list)):
-            adv_loss = 0.0
+        adv_loss = 0.0
+        if isinstance(outputs, (tuple, list)):
             for i, outputs_ in enumerate(outputs):
                 if isinstance(outputs_, (tuple, list)):
                     # NOTE(kan-bayashi): case including feature maps
@@ -55,9 +55,10 @@ class GeneratorAdversarialLoss(torch.nn.Module):
             if self.average_by_discriminators:
                 adv_loss /= i + 1
         else:
-            adv_loss = self.criterion(outputs)
-
-        return adv_loss / len(outputs)
+            for i, outputs_ in enumerate(outputs):
+                adv_loss += self.criterion(outputs_)
+            adv_loss /= i + 1
+
+        return adv_loss

     def _mse_loss(self, x):
         return F.mse_loss(x, x.new_ones(x.size()))
@@ -112,9 +113,9 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
             Tensor: Discriminator fake loss value.

        """
-        if isinstance(outputs, (tuple, list)):
-            real_loss = 0.0
-            fake_loss = 0.0
+        real_loss = 0.0
+        fake_loss = 0.0
+        if isinstance(outputs, (tuple, list)):
             for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                 if isinstance(outputs_hat_, (tuple, list)):
                     # NOTE(kan-bayashi): case including feature maps
@@ -126,10 +127,13 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
                 fake_loss /= i + 1
                 real_loss /= i + 1
         else:
-            real_loss = self.real_criterion(outputs)
-            fake_loss = self.fake_criterion(outputs_hat)
-
-        return real_loss / len(outputs), fake_loss / len(outputs)
+            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+                real_loss += self.real_criterion(outputs_)
+                fake_loss += self.fake_criterion(outputs_hat_)
+            fake_loss /= i + 1
+            real_loss /= i + 1
+
+        return real_loss, fake_loss

     def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
         return F.mse_loss(x, x.new_ones(x.size()))
@@ -204,7 +208,7 @@ class FeatureLoss(torch.nn.Module):
         if self.average_by_discriminators:
             feat_match_loss /= i + 1

-        return feat_match_loss / (len(feats) * len(feats[0]))
+        return feat_match_loss


 class MelSpectrogramReconstructionLoss(torch.nn.Module):
@@ -233,7 +237,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
-                    n_fft=s,
+                    n_fft=max(s, 512),
                     win_length=s,
                     hop_length=s // 4,
                     n_mels=n_mels,
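
A plausible reading of the `n_fft=max(s, 512)` change (not stated in the commit message): with `n_fft=s`, the smallest scales of the multi-scale mel loss only get `s // 2 + 1` frequency bins, which can fall below `n_mels` and leave empty mel filters, while clamping the FFT size at 512 keeps the filterbank well defined as the window and hop still shrink with `s`. A minimal sketch, assuming torchaudio's `MelSpectrogram`, 24 kHz audio, power-of-two window sizes, and `n_mels=80` (the recipe's actual values may differ):

# Sketch only: shows the effect of n_fft = max(s, 512) for small windows.
# The window sizes and n_mels below are assumptions for illustration.
from torchaudio.transforms import MelSpectrogram

for s in [2**i for i in range(5, 12)]:  # 32 ... 2048 (assumed scales)
    mel = MelSpectrogram(
        sample_rate=24000,
        n_fft=max(s, 512),  # was n_fft=s before this commit
        win_length=s,
        hop_length=s // 4,
        n_mels=80,
    )
    # with n_fft=s, the smallest scales would have s // 2 + 1 < 80 frequency bins
    print(f"win_length={s:5d}  n_fft={max(s, 512)}  freq_bins={max(s, 512) // 2 + 1}")
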
@@ -462,8 +466,15 @@ def loss_g(


 if __name__ == "__main__":
-    la = FeatureLoss(average_by_layers=False, average_by_discriminators=False)
-    aa = [torch.rand(192, 192) for _ in range(3)]
-    bb = [torch.rand(192, 192) for _ in range(3)]
-    print(la(bb, aa))
-    print(feature_loss(aa, bb))
+    # la = FeatureLoss(average_by_layers=True, average_by_discriminators=True)
+    # aa = [torch.rand(192, 192) for _ in range(3)]
+    # bb = [torch.rand(192, 192) for _ in range(3)]
+    # print(la(bb, aa))
+    # print(feature_loss(aa, bb))
+    la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
+    aa = torch.Tensor([0.1, 0.2, 0.3, 0.4])
+    bb = torch.Tensor([0.4, 0.3, 0.2, 0.1])
+    print(la(aa))
+    print(adversarial_g_loss(aa))
+    print(la(bb))
+    print(adversarial_g_loss(bb))

egs/libritts/CODEC/encodec/scheduler.py  (new file, 141 lines)
@@ -0,0 +1,141 @@
+import math
+from bisect import bisect_right
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class WarmupLrScheduler(_LRScheduler):
+    def __init__(
+        self,
+        optimizer,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.warmup_epoch = warmup_epoch
+        self.warmup_ratio = warmup_ratio
+        self.warmup = warmup
+        super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        ratio = self.get_lr_ratio()
+        lrs = [ratio * lr for lr in self.base_lrs]
+        return lrs
+
+    def get_lr_ratio(self):
+        if self.last_epoch < self.warmup_epoch:
+            ratio = self.get_warmup_ratio()
+        else:
+            ratio = self.get_main_ratio()
+        return ratio
+
+    def get_main_ratio(self):
+        raise NotImplementedError
+
+    def get_warmup_ratio(self):
+        assert self.warmup in ("linear", "exp")
+        alpha = self.last_epoch / self.warmup_epoch
+        if self.warmup == "linear":
+            ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
+        elif self.warmup == "exp":
+            ratio = self.warmup_ratio ** (1.0 - alpha)
+        return ratio
+
+
+class WarmupPolyLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        power,
+        max_epoch,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.power = power
+        self.max_epoch = max_epoch
+        super(WarmupPolyLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        real_max_epoch = self.max_epoch - self.warmup_epoch
+        alpha = real_epoch / real_max_epoch
+        ratio = (1 - alpha) ** self.power
+        return ratio
+
+
+class WarmupExpLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        gamma,
+        interval=1,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.gamma = gamma
+        self.interval = interval
+        super(WarmupExpLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        ratio = self.gamma ** (real_epoch // self.interval)
+        return ratio
+
+
+class WarmupCosineLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        max_epoch,
+        eta_ratio=0,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.eta_ratio = eta_ratio
+        self.max_epoch = max_epoch
+        super(WarmupCosineLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_max_epoch = self.max_epoch - self.warmup_epoch
+        return (
+            self.eta_ratio
+            + (1 - self.eta_ratio)
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            / 2
+        )
+
+
+class WarmupStepLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        milestones: list,
+        gamma=0.1,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.milestones = milestones
+        self.gamma = gamma
+        super(WarmupStepLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        return ratio
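
For orientation, a minimal usage sketch of the new `WarmupCosineLrScheduler` (illustrative only, not part of the commit; the optimizer, learning rate, and epoch counts below are placeholders, and the import assumes the recipe directory is on `sys.path`). The ratio returned by `get_lr_ratio` ramps up during the first `warmup_epoch` epochs (exponentially by default) and then follows a half-cosine of `last_epoch / (max_epoch - warmup_epoch)` that decays toward roughly `eta_ratio` times the base learning rate:

# Illustrative sketch, not part of the commit.
import torch

from scheduler import WarmupCosineLrScheduler  # assumes egs/libritts/CODEC/encodec is on sys.path

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4, betas=(0.5, 0.9))

scheduler = WarmupCosineLrScheduler(
    optimizer=optimizer,
    max_epoch=20,      # placeholder for params.num_epochs
    eta_ratio=0.1,
    warmup_epoch=3,    # placeholder for params.discriminator_epoch_start
    warmup_ratio=1e-4,
)

for epoch in range(20):
    # ... run one training epoch, calling optimizer.step() per batch ...
    optimizer.step()
    scheduler.step()  # advance the per-epoch schedule
    print(epoch, scheduler.get_last_lr())
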
@@ -16,6 +16,7 @@ from encodec import Encodec
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from loss import adopt_weight
+from scheduler import WarmupCosineLrScheduler
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -188,10 +189,10 @@ def get_params() -> AttributeDict:
             "sampling_rate": 24000,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
-            "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
+            "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
             "lambda_feat": 3.0,  # loss scaling coefficient for feat loss
             "lambda_rec": 1.0,  # loss scaling coefficient for reconstruction loss
-            "lambda_com": 1.0,  # loss scaling coefficient for commitment loss
+            "lambda_com": 100.0,  # loss scaling coefficient for commitment loss
         }
     )

@@ -260,7 +261,7 @@ def get_model(params: AttributeDict) -> nn.Module:
     # }
     # discriminator_params = {
     # "stft_discriminator_n_filters": 32,
-    # "discriminator_iter_start": 500,
+    # "discriminator_epoch_start": 5,
     # }
     # inference_params = {
     # "target_bw": 7.5,
@@ -275,7 +276,10 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
-        "discriminator_iter_start": 500,
+        "discriminator_epoch_start": 3,
+        "n_ffts": [1024, 2048, 512],
+        "hop_lengths": [256, 512, 128],
+        "win_lengths": [1024, 2048, 512],
     }
     inference_params = {
         "target_bw": 12,
||||||
@ -316,7 +320,10 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
multi_scale_discriminator=None,
|
multi_scale_discriminator=None,
|
||||||
multi_period_discriminator=None,
|
multi_period_discriminator=None,
|
||||||
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
|
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
|
||||||
n_filters=params.stft_discriminator_n_filters
|
n_filters=params.stft_discriminator_n_filters,
|
||||||
|
n_ffts=params.n_ffts,
|
||||||
|
hop_lengths=params.hop_lengths,
|
||||||
|
win_lengths=params.win_lengths,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
@@ -437,8 +444,8 @@ def train_one_epoch(
         with autocast(enabled=params.use_fp16):
             d_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )
             # forward discriminator
             (
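
`adopt_weight` is imported from `loss.py`; the commit only changes its arguments, gating by `params.cur_epoch` against `discriminator_epoch_start` instead of by `params.batch_idx_train` against `discriminator_iter_start`. A sketch of the usual form of this helper, assuming it follows the common taming-transformers-style implementation (the actual definition lives in `loss.py`):

# Sketch of the assumed adopt_weight behaviour (see loss.py for the real one):
# the weight is replaced by `value` until `step` reaches `threshold`.
def adopt_weight(weight: float, step: int, threshold: int = 0, value: float = 0.0) -> float:
    if step < threshold:
        weight = value
    return weight

# With this commit the gating is per epoch rather than per batch:
# adopt_weight(params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start)
# would return 0.0 for the first discriminator_epoch_start epochs and lambda_adv afterwards.
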
@@ -473,8 +480,8 @@ def train_one_epoch(
         with autocast(enabled=params.use_fp16):
             g_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )
             # forward generator
             (
@@ -507,7 +514,7 @@ def train_one_epoch(
             gen_loss = (
                 gen_adv_loss
                 + reconstruction_loss
-                + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                + params.lambda_feat * feature_loss
                 + params.lambda_com * commit_loss
             )
             for k, v in stats_g.items():
@@ -688,8 +695,8 @@ def compute_validation_loss(

             d_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )

             # forward discriminator
@@ -721,8 +728,8 @@ def compute_validation_loss(

             g_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )
             # forward generator
             (
@@ -753,7 +760,7 @@ def compute_validation_loss(
             gen_loss = (
                 gen_adv_loss
                 + reconstruction_loss
-                + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                + params.lambda_feat * feature_loss
                 + params.lambda_com * commit_loss
             )
             assert gen_loss.requires_grad is False
@@ -831,8 +838,8 @@ def scan_pessimistic_batches_for_oom(
                 + disc_scale_fake_adv_loss
             ) * adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_train_start,
             )
             optimizer_d.zero_grad()
             loss_d.backward()
@@ -859,8 +866,8 @@ def scan_pessimistic_batches_for_oom(
                 (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
                 * adopt_weight(
                     params.lambda_adv,
-                    params.batch_idx_train,
-                    threshold=params.discriminator_iter_start,
+                    0,
+                    threshold=params.discriminator_epoch_start,
                 )
                 + (
                     params.lambda_wav * wav_reconstruction_loss
@@ -1000,8 +1007,20 @@ def run(rank, world_size, args):
         betas=(0.5, 0.9),
     )

-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+    scheduler_g = WarmupCosineLrScheduler(
+        optimizer=optimizer_g,
+        max_epoch=params.num_epochs,
+        eta_ratio=0.1,
+        warmup_epoch=params.discriminator_epoch_start,
+        warmup_ratio=1e-4,
+    )
+    scheduler_d = WarmupCosineLrScheduler(
+        optimizer=optimizer_d,
+        max_epoch=params.num_epochs,
+        eta_ratio=0.1,
+        warmup_epoch=params.discriminator_epoch_start,
+        warmup_ratio=1e-4,
+    )

     if checkpoints is not None:
         # load state_dict for optimizers
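
With the default `warmup="exp"`, the warmup multiplier at epoch `e < warmup_epoch` is `warmup_ratio ** (1 - e / warmup_epoch)`. Here `warmup_epoch=params.discriminator_epoch_start` (set to 3 in `get_model` above) and `warmup_ratio=1e-4`, so the learning-rate factor is roughly 1e-4, 2.2e-3, and 4.6e-2 for epochs 0, 1, and 2 before the cosine phase begins. A quick check (illustrative only):

# Quick check of the exponential warmup factors used above (illustrative only).
warmup_ratio, warmup_epoch = 1e-4, 3  # 3 stands in for params.discriminator_epoch_start
for e in range(warmup_epoch):
    print(e, warmup_ratio ** (1.0 - e / warmup_epoch))
# 0 -> 0.0001
# 1 -> ~0.00215
# 2 -> ~0.0464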