init commit

This commit is contained in:
zr_jin 2024-10-24 22:58:52 +08:00
parent aa9132c82c
commit 32728ddc54
2 changed files with 110 additions and 0 deletions

View File

@ -0,0 +1,55 @@
from io import BytesIO
from pathlib import Path
from typing import NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
from lhotse.audio.backend import AudioBackend, FileObject
from lhotse.utils import Pathlike, Seconds
class _EEGAudioInfo(NamedTuple):
    """Metadata record returned by ``WujiEEGBackend.info``."""

    channels: int
    frames: int
    samplerate: int
    duration: float
    video: Optional[object] = None


class WujiEEGBackend(AudioBackend):
    """Lhotse audio backend that reads EEG signals from NumPy archives.

    Every file is loaded with ``np.load`` and must expose two entries:
    ``"fs"`` (the sampling rate) and ``"eeg"`` (the signal). The first axis
    of ``"eeg"`` is treated as the time axis. Saving is not supported.
    """

    @staticmethod
    def _load(path_or_fd: Union[Pathlike, FileObject]) -> Tuple[np.ndarray, int]:
        """Return ``(eeg, sampling_rate)`` and release the archive handle."""
        # Assumes np.load returns an NpzFile here (usable as a context
        # manager), which the keyed access to "fs"/"eeg" implies; a plain
        # .npy array would not support ``with`` -- TODO confirm.
        with np.load(path_or_fd) as archive:
            # ``.item()`` also copes with "fs" being stored as a 1-element
            # array, where a bare ``int(array)`` is deprecated by NumPy.
            sampling_rate = int(np.asarray(archive["fs"]).item())
            eeg = archive["eeg"]
        return eeg, sampling_rate

    def read_audio(
        self,
        path_or_fd: Union[Pathlike, FileObject],
        offset: Seconds = 0.0,
        duration: Optional[Seconds] = None,
        force_opus_sampling_rate: Optional[int] = None,
    ) -> Tuple[np.ndarray, int]:
        """Read a (possibly partial) EEG signal.

        Args:
            path_or_fd: Path or file object accepted by ``np.load``.
            offset: Start of the excerpt, in seconds.
            duration: Length of the excerpt, in seconds; ``None`` reads up
                to the end of the signal.
            force_opus_sampling_rate: Accepted for interface compatibility
                and ignored.

        Returns:
            A tuple ``(samples, sampling_rate)``.

        Raises:
            ValueError: If ``offset`` or ``duration`` is negative.
        """
        if offset < 0:
            raise ValueError(f"offset must be non-negative, got {offset}")
        if duration is not None and duration < 0:
            raise ValueError(f"duration must be non-negative, got {duration}")

        eeg, sampling_rate = self._load(path_or_fd)
        # Slice bounds must be integers; seconds * rate is a float.
        start = int(round(offset * sampling_rate))
        stop = None
        if duration is not None:
            stop = start + int(round(duration * sampling_rate))
        return eeg[start:stop], sampling_rate

    def is_applicable(self, path_or_fd: Union[Pathlike, FileObject]) -> bool:
        """Report this backend as applicable to every input."""
        return True

    def supports_save(self) -> bool:
        """Saving is not implemented for this backend."""
        return False

    def save_audio(
        self,
        dest: Union[str, Path, BytesIO],
        src: Union[torch.Tensor, np.ndarray],
        sampling_rate: int,
        format: Optional[str] = None,
        encoding: Optional[str] = None,
    ) -> None:
        """Always raise ``NotImplementedError``; this backend is read-only."""
        raise NotImplementedError("Saving audio is not supported for the WujiEEGBackend.")

    def supports_info(self) -> bool:
        """Metadata can be obtained through :meth:`info`."""
        return True

    def info(
        self,
        path_or_fd: Union[Pathlike, FileObject],
    ) -> _EEGAudioInfo:
        """Return metadata for the EEG signal stored in ``path_or_fd``.

        Returns:
            A named tuple with ``channels`` (always 1), ``frames`` (number
            of samples along the first axis), ``samplerate``, ``duration``
            (in seconds) and ``video`` (always ``None``).
        """
        eeg, sampling_rate = self._load(path_or_fd)
        num_samples = eeg.shape[0]
        return _EEGAudioInfo(
            channels=1,
            # Sample count, not seconds.
            frames=num_samples,
            samplerate=sampling_rate,
            duration=num_samples / sampling_rate,
            video=None,
        )

View File

@ -0,0 +1,55 @@
from pathlib import Path
from backend_np import WujiEEGBackend
from lhotse import CutSet, MonoCut, Recording, SupervisionSegment
from lhotse.audio.backend import set_current_audio_backend
from tqdm import tqdm
# Register the EEG backend as lhotse's current audio backend at import time.
set_current_audio_backend(WujiEEGBackend())

# CSV index read by the script entry point (its first line is skipped).
SPLIT = Path("/nvme3/wyc/sleep-net-zero/index/sleep_staging/hsp_nsrr.csv")
# Directory that receives the generated train/val cut manifests.
DATA_DIR = Path("/home/jinzengrui/proj/biofall/egs/tokenizer/CODEC/data/from_wyc")
if __name__ == "__main__":
    # Build lhotse cut manifests (train / val) from the CSV index at SPLIT.
    # Each data row must hold four fields: npz path, session id, duration in
    # seconds, and split name ("train" or "val").
    import csv  # local import: only the script entry point needs it

    # newline="" lets the csv module do its own line-ending handling.
    with open(SPLIT, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (tolerates an empty file)
        # Drop blank rows so a trailing empty line does not break unpacking.
        rows = [row for row in reader if any(field.strip() for field in row)]

    train_cuts, val_cuts = [], []
    for row in tqdm(rows):
        # Raises ValueError when a row does not have exactly four fields.
        npz_path, sess_id, duration, split = (field.strip() for field in row)
        duration = float(duration)
        npz_path = Path(npz_path)
        # Keep only the part of the file name before the first dot.
        npz_fname = npz_path.stem.split(".")[0]
        # Recording, cut and supervision all share the same identifier.
        item_id = f"{sess_id}-{npz_fname}"

        audio = Recording.from_file(npz_path, recording_id=item_id)
        cut = MonoCut(
            id=item_id,
            start=0.0,
            duration=duration,
            channel=0,
            recording=audio,
            supervisions=[
                SupervisionSegment(
                    id=item_id,
                    recording_id=item_id,
                    start=0.0,
                    duration=duration,
                    channel=0,
                    text="",
                    language="",
                    speaker=sess_id,
                )
            ],
        )

        if split == "train":
            train_cuts.append(cut)
        elif split == "val":
            val_cuts.append(cut)
        else:
            raise ValueError(f"Unknown split: {split}")

    # Make sure the destination exists before writing the manifests.
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    train_cuts = CutSet.from_cuts(cuts=train_cuts)
    train_cuts.to_jsonl(DATA_DIR / "train.jsonl.gz")
    val_cuts = CutSet.from_cuts(cuts=val_cuts)
    val_cuts.to_jsonl(DATA_DIR / "val.jsonl.gz")