Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

commit c345a4a572 (parent 39b3d1a050)

    init commit for the EEG setup
@@ -26,6 +26,7 @@ from loss import (
     FeatureLoss,
     GeneratorAdversarialLoss,
     MelSpectrogramReconstructionLoss,
+    SpectrogramReconstructionLoss,
     WavReconstructionLoss,
 )
 from torch import nn
@@ -79,7 +80,7 @@ class Encodec(nn.Module):
         )
         self.feature_match_loss = FeatureLoss()
         self.wav_reconstruction_loss = WavReconstructionLoss()
-        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+        self.spec_reconstruction_loss = SpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
         )
@@ -170,7 +171,7 @@ class Encodec(nn.Module):
         wav_reconstruction_loss = self.wav_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )
-        mel_reconstruction_loss = self.mel_reconstruction_loss(
+        mel_reconstruction_loss = self.spec_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )
@@ -14,7 +14,7 @@ from typing import List, Tuple, Union

 import torch
 import torch.nn.functional as F
-from torchaudio.transforms import MelSpectrogram
+from torchaudio.transforms import MelSpectrogram, Spectrogram


 class GeneratorAdversarialLoss(torch.nn.Module):
@@ -295,6 +295,80 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
         return mel_loss


+class SpectrogramReconstructionLoss(torch.nn.Module):
+    """Multi-resolution spectrogram reconstruction loss designed for EEG signals."""
+
+    def __init__(
+        self,
+        sampling_rate: int = 22050,
+        return_spec: bool = False,
+    ):
+        super().__init__()
+        self.sampling_rate = sampling_rate
+        # Five STFT resolutions: n_fft = 2**i // 8 for i in 5..9, i.e. 4..64.
+        self.wav_to_specs = []
+        for i in range(5, 10):
+            s = 2**i // 8
+            self.wav_to_specs.append(
+                Spectrogram(
+                    n_fft=s,
+                    win_length=s,
+                    hop_length=s // 4,
+                    normalized=True,
+                    center=False,
+                    pad_mode=None,
+                    power=None,
+                )
+            )
+        self.return_spec = return_spec
+
+    def forward(
+        self,
+        x_hat: torch.Tensor,
+        x: torch.Tensor,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+        """Calculate the spectrogram reconstruction loss.
+
+        Args:
+            x_hat (Tensor): Generated waveform tensor (B, 1, T).
+            x (Tensor): Groundtruth waveform tensor (B, 1, T).
+
+        Returns:
+            Tensor: Spectrogram loss value.
+
+        """
+        spec_loss = 0.0
+
+        for wav_to_spec in self.wav_to_specs:
+            wav_to_spec.to(x.device)
+
+            spec_hat = wav_to_spec(x_hat.squeeze(1))
+            spec = wav_to_spec(x.squeeze(1))
+
+            # L1 distance on the complex spectrograms plus a log-magnitude
+            # RMSE term, accumulated over all resolutions.
+            spec_loss += (
+                F.l1_loss(spec_hat, spec, reduction="mean")
+                + (
+                    (
+                        (torch.log(spec.abs() + 1e-7) - torch.log(spec_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )
+
+        # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
+        # mel = self.wav_to_spec(x.squeeze(1))
+        # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)
+
+        if self.return_spec:
+            return spec_loss, (spec_hat, spec)
+
+        return spec_loss
+
+
 class WavReconstructionLoss(torch.nn.Module):
     """Wav Reconstruction loss."""
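For orientation, a minimal sketch exercising the new loss on random tensors (not part of the commit; it assumes the recipe's loss.py is importable, and the shapes are illustrative):

    import torch

    from loss import SpectrogramReconstructionLoss

    # The constructor builds five STFT resolutions with
    # n_fft = 2**i // 8 for i in 5..9, i.e. 4, 8, 16, 32, 64.
    loss_fn = SpectrogramReconstructionLoss(sampling_rate=128)

    x = torch.randn(2, 1, 23040)      # (B, 1, T): 180 s of EEG at 128 Hz
    x_hat = torch.randn(2, 1, 23040)  # generator output, same shape

    loss = loss_fn(x_hat=x_hat, x=x)
    print(loss.item())  # L1 + log-magnitude RMSE, summed over resolutions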
@@ -32,13 +32,14 @@ import torch.nn as nn
 from codec_datamodule import CodecDataModule
 from encodec import Encodec
 from lhotse.utils import fix_random_seed
+from PIL import Image
 from scheduler import WarmupCosineLrScheduler
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
-from utils import MetricsTracker, save_checkpoint
+from utils import MetricsTracker, plot_curve, plot_feature, save_checkpoint

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -155,6 +156,20 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=128,
+        help="The sampling rate of the biomarker.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=int,
+        default=180,
+        help="The chunk size of the biomarker (in seconds).",
+    )
+
     return parser
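A hypothetical invocation using the two new flags (the script path and experiment directory are illustrative; the remaining options already exist in the parser):

    python train.py \
      --world-size 1 \
      --sampling-rate 128 \
      --chunk-size 180 \
      --exp-dir exp/encodec-eeg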
@@ -202,8 +217,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "valid_interval": 200,
             "env_info": get_env_info(),
-            "sampling_rate": 24000,
-            "audio_normalization": False,
+            "wave_normalization": False,
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
             "lambda_feat": 4.0,  # loss scaling coefficient for feat loss
@@ -352,27 +366,27 @@ def prepare_input(
     is_training: bool = True,
 ):
     """Parse batch data"""
-    audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
+    wave = batch["audio"].to(device, memory_format=torch.contiguous_format)
     features = batch["features"].to(device, memory_format=torch.contiguous_format)
-    audio_lens = batch["audio_lens"].to(device)
+    wave_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)

     if is_training:
-        audio_dims = audio.size(-1)
+        audio_dims = wave.size(-1)
         start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate))
-        audio = audio[:, start_idx : params.sampling_rate + start_idx]
+        wave = wave[:, start_idx : params.sampling_rate + start_idx]
     else:
         # NOTE(zengrui): a very coarse setup
-        audio = audio[
+        wave = wave[
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]

-    if params.audio_normalization:
-        mean = audio.mean(dim=-1, keepdim=True)
-        std = audio.std(dim=-1, keepdim=True)
-        audio = (audio - mean) / (std + 1e-7)
+    if params.wave_normalization:
+        mean = wave.mean(dim=-1, keepdim=True)
+        std = wave.std(dim=-1, keepdim=True)
+        wave = (wave - mean) / (std + 1e-7)

-    return audio, audio_lens, features, features_lens
+    return wave, wave_lens, features, features_lens


 def train_discriminator(weight, global_step, threshold=0, value=0.0):
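The training branch crops a random window of params.sampling_rate samples from each waveform. A standalone sketch of that crop, with hypothetical names (crop_wave is not part of the commit):

    import random

    import torch

    def crop_wave(wave: torch.Tensor, chunk_samples: int) -> torch.Tensor:
        """Randomly crop a batch of (B, T) waveforms to chunk_samples samples."""
        max_start = max(0, wave.size(-1) - chunk_samples)
        start_idx = random.randint(0, max_start)
        return wave[:, start_idx : start_idx + chunk_samples]

    wave = torch.randn(4, 30000)    # batch of 4 waveforms, 30000 samples each
    chunk = crop_wave(wave, 23040)  # 128 Hz * 180 s after the override
    print(chunk.shape)              # torch.Size([4, 23040])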
@ -623,17 +637,29 @@ def train_one_epoch(
|
||||
if speech_hat_i.dim() > 1:
|
||||
speech_hat_i = speech_hat_i.squeeze(0)
|
||||
speech_i = speech_i.squeeze(0)
|
||||
tb_writer.add_audio(
|
||||
f"train/speech_hat_",
|
||||
speech_hat_i,
|
||||
# tb_writer.add_audio(
|
||||
# f"train/speech_hat_",
|
||||
# speech_hat_i,
|
||||
# params.batch_idx_train,
|
||||
# params.sampling_rate,
|
||||
# )
|
||||
# tb_writer.add_audio(
|
||||
# f"train/speech_",
|
||||
# speech_i,
|
||||
# params.batch_idx_train,
|
||||
# params.sampling_rate,
|
||||
# )
|
||||
tb_writer.add_image(
|
||||
"train/speech_hat_",
|
||||
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
dataformats="HWC",
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
f"train/speech_",
|
||||
speech_i,
|
||||
tb_writer.add_image(
|
||||
"train/speech_",
|
||||
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
dataformats="HWC",
|
||||
)
|
||||
# tb_writer.add_image(
|
||||
# "train/mel_hat_",
|
||||
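Note that the audio summaries are commented out rather than deleted: playback of a low-rate EEG signal is of little use, so the recipe logs each waveform as a plotted image instead (see plot_curve in the utils hunk below).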
@ -675,19 +701,32 @@ def train_one_epoch(
|
||||
if speech_hat_i.dim() > 1:
|
||||
speech_hat_i = speech_hat_i.squeeze(0)
|
||||
speech_i = speech_i.squeeze(0)
|
||||
tb_writer.add_audio(
|
||||
# tb_writer.add_audio(
|
||||
# f"train/valid_speech_hat_{index}",
|
||||
# speech_hat_i,
|
||||
# params.batch_idx_train,
|
||||
# params.sampling_rate,
|
||||
# )
|
||||
# tb_writer.add_audio(
|
||||
# f"train/valid_speech_{index}",
|
||||
# speech_i,
|
||||
# params.batch_idx_train,
|
||||
# params.sampling_rate,
|
||||
# )
|
||||
tb_writer.add_image(
|
||||
f"train/valid_speech_hat_{index}",
|
||||
speech_hat_i,
|
||||
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
dataformats="HWC",
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
tb_writer.add_image(
|
||||
f"train/valid_speech_{index}",
|
||||
speech_i,
|
||||
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
dataformats="HWC",
|
||||
)
|
||||
|
||||
|
||||
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
@@ -1164,11 +1203,19 @@ def run(rank, world_size, args):
         cleanup_dist()


+def override_sampling_rate(args) -> int:
+    logging.info(
+        f"Overriding sampling rate from {args.sampling_rate} to {args.sampling_rate * args.chunk_size}"
+    )
+    return int(args.sampling_rate * args.chunk_size)
+
+
 def main():
     parser = get_parser()
     CodecDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    args.sampling_rate = override_sampling_rate(args)

     world_size = args.world_size
     assert world_size >= 1
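With the defaults (--sampling-rate 128, --chunk-size 180) the override sets args.sampling_rate to 128 * 180 = 23040, so downstream params.sampling_rate denotes the chunk length in samples (one 180-second EEG window) rather than a per-second rate; this is the length prepare_input crops to and the length plot_curve receives.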
@@ -118,6 +118,20 @@ def plot_feature(spectrogram):
     plt.close()
     return data

+
+def plot_curve(speech: torch.Tensor, sampling_rate: int) -> "io.BytesIO":
+    """Plot a waveform and return it as an in-memory JPEG buffer."""
+    import io
+
+    import matplotlib.pyplot as plt
+    import numpy as np
+
+    plt.figure()
+    plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
+    buf = io.BytesIO()
+    plt.savefig(buf, format="jpeg")
+    buf.seek(0)
+    plt.close()
+    return buf
+
+
 class MetricsTracker(collections.defaultdict):
     def __init__(self):
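A sketch of how train.py consumes plot_curve (not part of the commit): the JPEG buffer is decoded with PIL and logged to TensorBoard as an HWC image. The log dir and tensor contents below are illustrative.

    import numpy as np
    import torch
    from PIL import Image
    from torch.utils.tensorboard import SummaryWriter

    from utils import plot_curve  # recipe-local helper defined above

    writer = SummaryWriter("exp/tb-demo")
    speech = torch.randn(23040)  # one 180 s chunk at 128 Hz, as cropped in training
    buf = plot_curve(speech, sampling_rate=23040)
    writer.add_image("demo/speech", np.array(Image.open(buf)), 0, dataformats="HWC")
    writer.close()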