commit 5492a6a5e2 (parent df87a0fe2c)

    comments updated
@@ -1,8 +1,19 @@
# Modified from egs/ljspeech/TTS/vits/loss.py by: Zengrui JIN (Tsinghua University)
# original implementation is from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encodec-related loss modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
from torchaudio.transforms import MelSpectrogram

@@ -225,15 +236,6 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
        self.wav_to_specs = []
        for i in range(5, 12):
            s = 2**i
            # self.wav_to_specs.append(
            #     Wav2LogFilterBank(
            #         sampling_rate=sampling_rate,
            #         frame_length=s,
            #         frame_shift=s // 4,
            #         use_fft_mag=use_fft_mag,
            #         num_filters=n_mels,
            #     )
            # )
            self.wav_to_specs.append(
                MelSpectrogram(
                    sample_rate=sampling_rate,

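The loop above drops the commented-out lhotse Wav2LogFilterBank front end and builds one torchaudio MelSpectrogram per window size 2**5 through 2**11, giving a multi-resolution view of the reconstructed waveform. Below is a minimal sketch of how such a multi-scale mel reconstruction loss can be put together, assuming an L1 objective, a (batch, num_samples) waveform layout, and an n_mels cap for the smallest windows; it illustrates the idea rather than reproducing the recipe's exact module.

import torch
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram


def multi_scale_mel_loss(
    x_hat: torch.Tensor,  # reconstructed waveform, shape (batch, num_samples)
    x: torch.Tensor,      # reference waveform, same shape
    sampling_rate: int = 24000,
    n_mels: int = 80,     # assumed value; the recipe's setting may differ
) -> torch.Tensor:
    # One mel transform per window size 2**5 .. 2**11, hop = window // 4,
    # mirroring the loop in the hunk above.
    wav_to_specs = [
        MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=2**i,
            win_length=2**i,
            hop_length=2**i // 4,
            # cap n_mels for the smallest windows so every filter has support
            n_mels=min(n_mels, 2**i // 4),
        ).to(x.device)
        for i in range(5, 12)
    ]
    loss = x.new_zeros(())
    for wav_to_spec in wav_to_specs:
        # L1 distance between mel spectrograms at this resolution
        loss = loss + F.l1_loss(wav_to_spec(x_hat), wav_to_spec(x))
    return loss / len(wav_to_specs)

The recipe instead builds the transforms once (self.wav_to_specs in the module's constructor), which avoids recomputing the mel filterbanks on every forward pass.
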
@@ -186,12 +186,11 @@ def get_params() -> AttributeDict:
            "env_info": get_env_info(),
            "sampling_rate": 24000,
            "audio_normalization": False,
            "chunk_size": 1.0,  # in seconds
            "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
            "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
            "lambda_feat": 4.0,  # loss scaling coefficient for feat loss
            "lambda_rec": 1.0,  # loss scaling coefficient for reconstruction loss
            "lambda_com": 1.0,  # loss scaling coefficient for commitment loss
            "lambda_com": 1000.0,  # loss scaling coefficient for commitment loss
        }
    )

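The lambda_* entries are scaling coefficients applied to the individual generator-side losses (adversarial, feature-matching, reconstruction, waveform, and VQ commitment terms). A minimal sketch of how such weights are typically combined into the total generator loss follows; the function and argument names are assumptions for illustration, and the recipe's training script may assemble the sum differently.

import torch


def total_generator_loss(
    params,                     # e.g. the AttributeDict returned by get_params()
    adv_loss: torch.Tensor,     # adversarial loss (fool the discriminators)
    feat_loss: torch.Tensor,    # discriminator feature-matching loss
    rec_loss: torch.Tensor,     # multi-scale mel reconstruction loss
    wav_loss: torch.Tensor,     # time-domain waveform loss
    commit_loss: torch.Tensor,  # VQ commitment loss
) -> torch.Tensor:
    # Each term is weighted by its lambda_* coefficient from get_params().
    return (
        params.lambda_adv * adv_loss
        + params.lambda_feat * feat_loss
        + params.lambda_rec * rec_loss
        + params.lambda_wav * wav_loss
        + params.lambda_com * commit_loss
    )
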
@@ -342,12 +341,10 @@ def prepare_input(

    if is_training:
        audio_dims = audio.size(-1)
        start_idx = random.randint(
            0, max(0, audio_dims - params.chunk_size * params.sampling_rate)
        )
        start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate))
        audio = audio[:, start_idx : params.sampling_rate + start_idx]
    else:
        # NOTE: a very coarse setup
        # NOTE(zengrui): a very coarse setup
        audio = audio[
            :, params.sampling_rate : params.sampling_rate + params.sampling_rate
        ]

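During training, the hunk above crops a random window of params.sampling_rate samples (one second at 24 kHz) from each utterance; during validation it takes the fixed slice from the one-second mark to the two-second mark. A minimal standalone sketch of the training-time crop, using a hypothetical helper name and an assumed (batch, num_samples) layout:

import random

import torch


def crop_training_chunk(audio: torch.Tensor, sampling_rate: int = 24000) -> torch.Tensor:
    """Randomly crop a one-second chunk, as in the is_training branch above.

    `audio` is assumed to have shape (batch, num_samples).
    """
    num_samples = audio.size(-1)
    # If the utterance is shorter than one second, start at 0 and let the
    # slice return whatever samples are available.
    start_idx = random.randint(0, max(0, num_samples - sampling_rate))
    return audio[:, start_idx : start_idx + sampling_rate]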