mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

commit c345a4a572 (parent 39b3d1a050)
init commit for the EEG setup

@@ -26,6 +26,7 @@ from loss import (
     FeatureLoss,
     GeneratorAdversarialLoss,
     MelSpectrogramReconstructionLoss,
+    SpectrogramReconstructionLoss,
     WavReconstructionLoss,
 )
 from torch import nn

@@ -79,7 +80,7 @@ class Encodec(nn.Module):
         )
         self.feature_match_loss = FeatureLoss()
         self.wav_reconstruction_loss = WavReconstructionLoss()
-        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+        self.spec_reconstruction_loss = SpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
         )
 
@@ -170,7 +171,7 @@ class Encodec(nn.Module):
         wav_reconstruction_loss = self.wav_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )
-        mel_reconstruction_loss = self.mel_reconstruction_loss(
+        mel_reconstruction_loss = self.spec_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )
 
@@ -14,7 +14,7 @@ from typing import List, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from torchaudio.transforms import MelSpectrogram
+from torchaudio.transforms import MelSpectrogram, Spectrogram
 
 
 class GeneratorAdversarialLoss(torch.nn.Module):

@@ -295,6 +295,80 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
         return mel_loss
 
 
+class SpectrogramReconstructionLoss(torch.nn.Module):
+    """Spec Reconstruction loss designed for EEG signals."""
+
+    def __init__(
+        self,
+        sampling_rate: int = 22050,
+        return_spec: bool = False,
+    ):
+        super().__init__()
+        self.wav_to_specs = []
+        for i in range(5, 10):
+            s = 2**i // 8
+            self.wav_to_specs.append(
+                Spectrogram(
+                    sample_rate=sampling_rate,
+                    n_fft=s,
+                    win_length=s,
+                    hop_length=s // 4,
+                    normalized=True,
+                    center=False,
+                    pad_mode=None,
+                    power=None,
+                )
+            )
+        self.return_mel = return_spec
+
+    def forward(
+        self,
+        x_hat: torch.Tensor,
+        x: torch.Tensor,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+        """Calculate Mel-spectrogram loss.
+
+        Args:
+            x_hat (Tensor): Generated waveform tensor (B, 1, T).
+            x (Tensor): Groundtruth waveform tensor (B, 1, T).
+            spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
+                (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
+                waveform.
+
+        Returns:
+            Tensor: Mel-spectrogram loss value.
+
+        """
+        mel_loss = 0.0
+
+        for i, wav_to_spec in enumerate(self.wav_to_specs):
+            s = 2 ** (i + 5)
+            wav_to_spec.to(x.device)
+
+            mel_hat = wav_to_spec(x_hat.squeeze(1))
+            mel = wav_to_spec(x.squeeze(1))
+
+            mel_loss += (
+                F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+                + (
+                    (
+                        (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )
+
+        # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
+        # mel = self.wav_to_spec(x.squeeze(1))
+        # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)
+
+        if self.return_mel:
+            return mel_loss, (mel_hat, mel)
+
+        return mel_loss
+
+
 class WavReconstructionLoss(torch.nn.Module):
     """Wav Reconstruction loss."""
 
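Note on the added class: over five STFT resolutions (n_fft = 2**i // 8 for i in range(5, 10), i.e. 4 to 64, suited to low-rate EEG) it sums an L1 term and a log-magnitude RMSE term. Below is a minimal self-contained sketch of the same idea, not the committed code: it drops the sample_rate argument (torchaudio.transforms.Spectrogram does not accept one; only MelSpectrogram does), uses power=1.0 so the transforms return real magnitudes rather than complex spectra, and avoids the deprecated reduce= flag of F.l1_loss.

    # Hedged sketch of a multi-resolution spectrogram loss in the spirit of
    # SpectrogramReconstructionLoss above; the resolutions and the
    # L1 + log-RMSE combination follow the diff, the rest is illustrative.
    import torch
    import torch.nn.functional as F
    from torchaudio.transforms import Spectrogram


    class MultiResolutionSpecLoss(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # n_fft = 2**i // 8 for i in range(5, 10) -> 4, 8, 16, 32, 64
            self.transforms = torch.nn.ModuleList(
                Spectrogram(
                    n_fft=s,
                    win_length=s,
                    hop_length=s // 4,
                    normalized=True,
                    center=False,
                    power=1.0,  # real magnitudes; power=None would be complex
                )
                for s in (2**i // 8 for i in range(5, 10))
            )

        def forward(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            # x, x_hat: (B, 1, T) waveform-like EEG tensors, as in the diff.
            loss = x.new_zeros(())
            for spec in self.transforms:
                s_hat, s_ref = spec(x_hat.squeeze(1)), spec(x.squeeze(1))
                loss = loss + F.l1_loss(s_hat, s_ref)
                log_diff = torch.log(s_ref + 1e-7) - torch.log(s_hat + 1e-7)
                # RMSE over the frequency axis, then mean over batch and time.
                loss = loss + (log_diff.pow(2).mean(dim=-2) ** 0.5).mean()
            return loss

Keeping the transforms in an nn.ModuleList also lets a parent module's .to(device) carry them along, instead of the per-step wav_to_spec.to(x.device) call in the diff.
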
@@ -32,13 +32,14 @@ import torch.nn as nn
 from codec_datamodule import CodecDataModule
 from encodec import Encodec
 from lhotse.utils import fix_random_seed
+from PIL import Image
 from scheduler import WarmupCosineLrScheduler
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
-from utils import MetricsTracker, save_checkpoint
+from utils import MetricsTracker, plot_curve, plot_feature, save_checkpoint
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint

@@ -155,6 +156,20 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=128,
+        help="The sampling rate of the biomarker.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=int,
+        default=180,
+        help="The chunk size of the biomarker (in second).",
+    )
+
     return parser
 
 
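A quick sanity check of the two new options (a sketch; it assumes the remaining options in get_parser all have defaults, so parse_args succeeds with only these flags):

    # Hypothetical smoke test for the flags added above.
    parser = get_parser()
    args = parser.parse_args(["--sampling-rate", "256", "--chunk-size", "60"])
    assert args.sampling_rate == 256  # Hz of the raw biomarker
    assert args.chunk_size == 60      # seconds per training chunk
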
@@ -202,8 +217,7 @@ def get_params() -> AttributeDict:
         "log_interval": 50,
         "valid_interval": 200,
         "env_info": get_env_info(),
-        "sampling_rate": 24000,
-        "audio_normalization": False,
+        "wave_normalization": False,
         "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
         "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
         "lambda_feat": 4.0,  # loss scaling coefficient for feat loss

@@ -352,27 +366,27 @@ def prepare_input(
     is_training: bool = True,
 ):
     """Parse batch data"""
-    audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
+    wave = batch["audio"].to(device, memory_format=torch.contiguous_format)
     features = batch["features"].to(device, memory_format=torch.contiguous_format)
-    audio_lens = batch["audio_lens"].to(device)
+    wave_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
 
     if is_training:
-        audio_dims = audio.size(-1)
+        audio_dims = wave.size(-1)
         start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate))
-        audio = audio[:, start_idx : params.sampling_rate + start_idx]
+        wave = wave[:, start_idx : params.sampling_rate + start_idx]
     else:
         # NOTE(zengrui): a very coarse setup
-        audio = audio[
+        wave = wave[
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]
 
-    if params.audio_normalization:
-        mean = audio.mean(dim=-1, keepdim=True)
-        std = audio.std(dim=-1, keepdim=True)
-        audio = (audio - mean) / (std + 1e-7)
+    if params.wave_normalization:
+        mean = wave.mean(dim=-1, keepdim=True)
+        std = wave.std(dim=-1, keepdim=True)
+        wave = (wave - mean) / (std + 1e-7)
 
-    return audio, audio_lens, features, features_lens
+    return wave, wave_lens, features, features_lens
 
 
 def train_discriminator(weight, global_step, threshold=0, value=0.0):
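The training branch above crops a random window of params.sampling_rate samples from each batch item. Since main() later multiplies the rate by --chunk-size (see the run() hunk below), that window is one full chunk, not one second. A toy illustration with made-up shapes:

    import random

    import torch

    # Hedged demo of the training-time crop in prepare_input above.
    sampling_rate = 128 * 180                  # effective rate after the override
    wave = torch.randn(4, 3 * sampling_rate)   # dummy batch: 4 signals, 3 chunks long
    start_idx = random.randint(0, max(0, wave.size(-1) - sampling_rate))
    wave = wave[:, start_idx : start_idx + sampling_rate]
    assert wave.shape == (4, sampling_rate)
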
@@ -623,17 +637,29 @@ def train_one_epoch(
                 if speech_hat_i.dim() > 1:
                     speech_hat_i = speech_hat_i.squeeze(0)
                     speech_i = speech_i.squeeze(0)
-                tb_writer.add_audio(
-                    f"train/speech_hat_",
-                    speech_hat_i,
+                # tb_writer.add_audio(
+                #     f"train/speech_hat_",
+                #     speech_hat_i,
+                #     params.batch_idx_train,
+                #     params.sampling_rate,
+                # )
+                # tb_writer.add_audio(
+                #     f"train/speech_",
+                #     speech_i,
+                #     params.batch_idx_train,
+                #     params.sampling_rate,
+                # )
+                tb_writer.add_image(
+                    "train/speech_hat_",
+                    np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
                     params.batch_idx_train,
-                    params.sampling_rate,
+                    dataformats="HWC",
                 )
-                tb_writer.add_audio(
-                    f"train/speech_",
-                    speech_i,
+                tb_writer.add_image(
+                    "train/speech_",
+                    np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
                     params.batch_idx_train,
-                    params.sampling_rate,
+                    dataformats="HWC",
                 )
                 # tb_writer.add_image(
                 #     "train/mel_hat_",

@@ -675,19 +701,32 @@ def train_one_epoch(
                 if speech_hat_i.dim() > 1:
                     speech_hat_i = speech_hat_i.squeeze(0)
                     speech_i = speech_i.squeeze(0)
-                tb_writer.add_audio(
+                # tb_writer.add_audio(
+                #     f"train/valid_speech_hat_{index}",
+                #     speech_hat_i,
+                #     params.batch_idx_train,
+                #     params.sampling_rate,
+                # )
+                # tb_writer.add_audio(
+                #     f"train/valid_speech_{index}",
+                #     speech_i,
+                #     params.batch_idx_train,
+                #     params.sampling_rate,
+                # )
+                tb_writer.add_image(
                     f"train/valid_speech_hat_{index}",
-                    speech_hat_i,
+                    np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
                     params.batch_idx_train,
-                    params.sampling_rate,
+                    dataformats="HWC",
                 )
-                tb_writer.add_audio(
+                tb_writer.add_image(
                     f"train/valid_speech_{index}",
-                    speech_i,
+                    np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
                     params.batch_idx_train,
-                    params.sampling_rate,
+                    dataformats="HWC",
                 )
 
 
     loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:

@@ -1164,11 +1203,19 @@ def run(rank, world_size, args):
         cleanup_dist()
 
 
+def override_sampling_rate(args) -> int:
+    logging.info(
+        f"Overriding sampling rate from {args.sampling_rate} to {args.sampling_rate * args.chunk_size}"
+    )
+    return int(args.sampling_rate * args.chunk_size)
+
+
 def main():
     parser = get_parser()
     CodecDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    args.sampling_rate = override_sampling_rate(args)
 
     world_size = args.world_size
     assert world_size >= 1

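With the defaults (--sampling-rate 128, --chunk-size 180), the override rewrites args.sampling_rate to 128 * 180 = 23040 before training starts, so every downstream use of params.sampling_rate (the crop in prepare_input, the plotting calls) operates on whole 180-second chunks. A self-contained check, with the helper lightly reflowed from the diff:

    import logging
    from types import SimpleNamespace


    def override_sampling_rate(args) -> int:
        # Scale the per-second rate by the chunk length, as in the diff.
        logging.info(
            f"Overriding sampling rate from {args.sampling_rate} "
            f"to {args.sampling_rate * args.chunk_size}"
        )
        return int(args.sampling_rate * args.chunk_size)


    args = SimpleNamespace(sampling_rate=128, chunk_size=180)
    assert override_sampling_rate(args) == 23040  # samples per chunk at 128 Hz
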
@@ -118,6 +118,20 @@ def plot_feature(spectrogram):
     plt.close()
     return data
 
+def plot_curve(speech: torch.Tensor, sampling_rate: int) -> bytes:
+    import io
+
+    import matplotlib.pyplot as plt
+    import numpy as np
+
+    plt.figure()
+    plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
+    buf = io.BytesIO()
+    plt.savefig(buf, format="jpeg")
+    buf.seek(0)
+    plt.close()
+    return buf
+
 
 class MetricsTracker(collections.defaultdict):
     def __init__(self):
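plot_curve is annotated -> bytes but actually returns the io.BytesIO buffer itself; the call sites in train_one_epoch still work because PIL's Image.open accepts any file-like object. A minimal round-trip (assuming plot_curve from the hunk above is in scope) showing how the buffer becomes the HWC array that SummaryWriter.add_image expects; the writer, log dir, and dummy signal are stand-ins:

    import numpy as np
    import torch
    from PIL import Image
    from torch.utils.tensorboard import SummaryWriter

    speech = torch.randn(23040)                    # one 180 s chunk at 128 Hz
    buf = plot_curve(speech, sampling_rate=23040)  # BytesIO holding a JPEG
    img = np.array(Image.open(buf))                # (H, W, 3) uint8 array
    writer = SummaryWriter("tb_demo")              # hypothetical log dir
    writer.add_image("train/speech_", img, 0, dataformats="HWC")
    writer.close()
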