mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit c558328dc5 (parent 7994684bf4): remove unused code
@@ -27,28 +27,100 @@ The generated fbank features are saved in data/fbank.
 import argparse
 import logging
 import os
+from dataclasses import dataclass
 from pathlib import Path
+from typing import Union
 
+import numpy as np
 import torch
-from lhotse import (
-    CutSet,
-    Fbank,
-    FbankConfig,
-    LilcomChunkyWriter,
-    load_manifest,
-    load_manifest_lazy,
-)
+from lhotse import CutSet, LilcomChunkyWriter, load_manifest
 from lhotse.audio import RecordingSet
+from lhotse.features.base import FeatureExtractor, register_extractor
 from lhotse.supervision import SupervisionSet
+from lhotse.utils import Seconds, compute_num_frames
+from matcha.utils.audio import mel_spectrogram
 
 from icefall.utils import get_executor
 
 # Torch's multithreaded behavior needs to be disabled, or
 # it wastes a lot of CPU and slows things down.
 # Do this outside of main() in case it needs to take effect
 # even when we are not invoking the main (e.g. when spawning subprocesses).
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
+@dataclass
+class MyFbankConfig:
+    n_fft: int
+    n_mels: int
+    sampling_rate: int
+    hop_length: int
+    win_length: int
+    f_min: float
+    f_max: float
+
+
+@register_extractor
+class MyFbank(FeatureExtractor):
+    name = "MyFbank"
+    config_type = MyFbankConfig
+
+    def __init__(self, config):
+        super().__init__(config=config)
+
+    @property
+    def device(self) -> Union[str, torch.device]:
+        return self.config.device
+
+    def feature_dim(self, sampling_rate: int) -> int:
+        return self.config.n_mels
+
+    def extract(
+        self,
+        samples: np.ndarray,
+        sampling_rate: int,
+    ) -> torch.Tensor:
+        # Check for sampling rate compatibility.
+        expected_sr = self.config.sampling_rate
+        assert sampling_rate == expected_sr, (
+            f"Mismatched sampling rate: extractor expects {expected_sr}, "
+            f"got {sampling_rate}"
+        )
+        samples = torch.from_numpy(samples)
+        assert samples.ndim == 2, samples.shape
+        assert samples.shape[0] == 1, samples.shape
+
+        mel = (
+            mel_spectrogram(
+                samples,
+                self.config.n_fft,
+                self.config.n_mels,
+                self.config.sampling_rate,
+                self.config.hop_length,
+                self.config.win_length,
+                self.config.f_min,
+                self.config.f_max,
+                center=False,
+            )
+            .squeeze()
+            .t()
+        )
+
+        assert mel.ndim == 2, mel.shape
+        assert mel.shape[1] == self.config.n_mels, mel.shape
+
+        num_frames = compute_num_frames(
+            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
+        )
+
+        if mel.shape[0] > num_frames:
+            mel = mel[:num_frames]
+        elif mel.shape[0] < num_frames:
+            mel = mel.unsqueeze(0)
+            mel = torch.nn.functional.pad(
+                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
+            ).squeeze(0)
+
+        return mel.numpy()
+
+    @property
+    def frame_shift(self) -> Seconds:
+        return self.config.hop_length / self.config.sampling_rate
+
+
 def get_parser():
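For readers following along, here is a minimal smoke test of the new extractor. It is a sketch, not part of the commit: it assumes the file above is importable as compute_fbank_ljspeech and that lhotse and matcha (for mel_spectrogram) are installed.

    # Minimal smoke test for the new extractor (a sketch, not part of the commit).
    import numpy as np

    from compute_fbank_ljspeech import MyFbank, MyFbankConfig

    config = MyFbankConfig(
        n_fft=1024,
        n_mels=80,
        sampling_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0,
        f_max=8000,
    )
    extractor = MyFbank(config)

    # extract() asserts a (1, num_samples) array at the configured rate.
    samples = np.random.randn(1, 22050).astype(np.float32)
    features = extractor.extract(samples, sampling_rate=22050)
    print(features.shape)  # (num_frames, 80); roughly one frame per 256 samples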
@@ -77,10 +149,15 @@ def compute_fbank_ljspeech(num_jobs: int):
     logging.info(f"num_jobs: {num_jobs}")
     logging.info(f"src_dir: {src_dir}")
     logging.info(f"output_dir: {output_dir}")
 
-    sampling_rate = 22050
-    frame_length = 1024 / sampling_rate  # (in seconds)
-    frame_shift = 256 / sampling_rate  # (in seconds)
+    config = MyFbankConfig(
+        n_fft=1024,
+        n_mels=80,
+        sampling_rate=22050,
+        hop_length=256,
+        win_length=1024,
+        f_min=0,
+        f_max=8000,
+    )
 
     prefix = "ljspeech"
     suffix = "jsonl.gz"
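As a quick sanity check on these numbers: hop_length=256 at 22050 Hz gives a frame shift of about 11.6 ms, i.e. roughly 86 frames per second of audio. A sketch, assuming lhotse's compute_num_frames is the same rounding helper the extractor calls above:

    # Sanity check on the config values above (illustrative only).
    from lhotse.utils import compute_num_frames

    frame_shift = 256 / 22050  # ~0.0116 s, matching MyFbank.frame_shift
    num_frames = compute_num_frames(
        duration=1.0, frame_shift=frame_shift, sampling_rate=22050
    )
    print(num_frames)  # ~86 frames for one second of audio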
@@ -93,25 +170,7 @@ def compute_fbank_ljspeech(num_jobs: int):
         src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
     )
 
-    # Differences with matcha-tts
-    # 1. we use pre-emphasis
-    # 2. we remove dc offset
-    # 3. we use a different window
-    # 4. we use a different mel filter bank matrix
-    # 5. we don't normalize features
-    config = FbankConfig(
-        sampling_rate=sampling_rate,
-        frame_length=frame_length,
-        frame_shift=frame_shift,
-        use_fft_mag=True,
-        low_freq=0,
-        high_freq=8000,
-        remove_dc_offset=False,
-        preemph_coeff=0,
-        # should be identical to n_feats in ../matcha/train.py
-        num_filters=80,
-    )
-    extractor = Fbank(config)
+    extractor = MyFbank(config)
 
     with get_executor() as ex:  # Initialize the executor only once.
         cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
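The hunk ends before the body of the with block. For context, a sketch of the usual lhotse feature-computation pattern these icefall scripts follow; the names cut_set, recordings, and supervisions are assumptions here, not lines from this commit:

    # Sketch of the usual feature-computation pattern (assumed, not verbatim):
    cut_set = CutSet.from_manifests(
        recordings=recordings, supervisions=supervisions
    )
    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/{prefix}_feats_{partition}",
        num_jobs=num_jobs,
        executor=ex,
        storage_type=LilcomChunkyWriter,
    )
    cut_set.to_file(output_dir / cuts_filename)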
@@ -135,6 +194,12 @@ def compute_fbank_ljspeech(num_jobs: int):
 
 
 if __name__ == "__main__":
+    # Torch's multithreaded behavior needs to be disabled, or
+    # it wastes a lot of CPU and slows things down.
+    # Do this outside of main() in case it needs to take effect
+    # even when we are not invoking the main (e.g. when spawning subprocesses).
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
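To exercise the script end to end, the function from the hunk headers can also be called directly; a sketch, since the actual CLI flags come from the elided get_parser():

    # Illustrative direct invocation, bypassing the elided get_parser().
    # Assumes the LJSpeech manifests from the earlier prepare stages exist.
    from compute_fbank_ljspeech import compute_fbank_ljspeech

    compute_fbank_ljspeech(num_jobs=4)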
@@ -35,6 +35,7 @@ from pathlib import Path
 
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset.speech_synthesis import validate_for_tts
+from compute_fbank_ljspeech import MyFbank
 
 
 def get_args():
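The import looks unused, but it matters: @register_extractor registers MyFbank under its name as a side effect of the import, which lhotse needs in order to reconstruct the extractor when it reads the manifests produced above. A sketch of what the registration enables; treat the exact lookup helper as an assumption about lhotse's API:

    # Importing the module runs @register_extractor, so lhotse can resolve
    # features whose manifest records the extractor name "MyFbank".
    from compute_fbank_ljspeech import MyFbank  # noqa: F401
    from lhotse.features.base import get_extractor_type

    print(get_extractor_type("MyFbank"))  # -> the MyFbank class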
@@ -3,18 +3,17 @@
 
 
 import argparse
+import json
 import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union
-import json
 
 import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
-from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.tokenizer import Tokenizer
 from matcha.utils.model import fix_len_compatibility
@@ -355,8 +354,6 @@ def compute_validation_loss(
 
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            if "tokens" in batch:
             (
                 audio,
                 audio_lens,
@@ -378,13 +375,6 @@ def compute_validation_loss(
             )
 
             batch_size = len(batch["tokens"])
-            else:
-                batch_size = batch["x"].shape[0]
-                batch["x"] = batch["x"].to(device)
-                batch["x_lengths"] = batch["x_lengths"].to(device)
-                batch["y"] = batch["y"].to(device)
-                batch["y_lengths"] = batch["y_lengths"].to(device)
-                losses = get_losses(batch)
 
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -478,7 +468,6 @@ def train_one_epoch(
             # features_lens, (N,), int32
             # tokens: List[List[str]], len(tokens) == N
 
-            if "tokens" in batch:
-                batch_size = len(batch["tokens"])
+            batch_size = len(batch["tokens"])
 
             (
@@ -489,11 +478,8 @@ def train_one_epoch(
                 tokens,
                 tokens_lens,
             ) = prepare_input(batch, tokenizer, device, params)
-            else:
-                batch_size = batch["x"].shape[0]
             try:
                 with autocast(enabled=params.use_fp16):
-                    if "tokens" in batch:
                     losses = get_losses(
                         {
                             "x": tokens,
@@ -504,12 +490,6 @@ def train_one_epoch(
                             "durations": None,
                         }
                     )
-                    else:
-                        batch["x"] = batch["x"].to(device)
-                        batch["x_lengths"] = batch["x_lengths"].to(device)
-                        batch["y"] = batch["y"].to(device)
-                        batch["y_lengths"] = batch["y_lengths"].to(device)
-                        losses = get_losses(batch)
 
                 loss = sum(losses.values())
@@ -535,8 +515,9 @@ def train_one_epoch(
                 raise
 
             if params.batch_idx_train % 100 == 0 and params.use_fp16:
-                # If the grad scale was less than 1, try increasing it. The _growth_interval
-                # of the grad scaler is configurable, but we can't configure it to have different
+                # If the grad scale was less than 1, try increasing it.
+                # The _growth_interval of the grad scaler is configurable,
+                # but we can't configure it to have different
                 # behavior depending on the current grad scale.
                 cur_grad_scale = scaler._scale.item()
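For context, the workaround this comment describes follows the pattern used elsewhere in icefall train scripts; a hedged sketch, where the thresholds and cadence are assumptions rather than lines from this diff:

    # Hedged sketch of the grad-scale workaround (thresholds are assumed):
    cur_grad_scale = scaler._scale.item()
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        # GradScaler.update() accepts an explicit new scale.
        scaler.update(cur_grad_scale * 2.0)
    if cur_grad_scale < 0.01:
        logging.warning(f"Grad scale is small: {cur_grad_scale}")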
@@ -560,7 +541,8 @@ def train_one_epoch(
 
                 logging.info(
                     f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                    f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+                    f"global_batch_idx: {params.batch_idx_train}, "
+                    f"batch size: {batch_size}, "
                     f"loss[{loss_info}], tot_loss[{tot_loss}], "
                     + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
                 )
@@ -588,7 +570,8 @@ def train_one_epoch(
                 model.train()
                 logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
                 logging.info(
-                    f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+                    "Maximum memory allocated so far is "
+                    f"{torch.cuda.max_memory_allocated()//1000000}MB"
                 )
                 if tb_writer is not None:
                     valid_info.write_summary(
@@ -658,13 +641,6 @@ def run(rank, world_size, args):
 
     logging.info("About to create datamodule")
 
-    if False:
-        params.data_args.tokenizer = tokenizer
-        data_module = TextMelDataModule(hparams=params.data_args)
-        del params.data_args.tokenizer
-        train_dl = data_module.train_dataloader()
-        valid_dl = data_module.val_dataloader()
-    else:
-        ljspeech = LJSpeechTtsDataModule(args)
+    ljspeech = LJSpeechTtsDataModule(args)
 
     train_cuts = ljspeech.train_cuts()
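With the dead TextMelDataModule branch gone, the lhotse-based datamodule is the only data path, which is also why its import could be dropped above. The continuation presumably builds the dataloaders in icefall's usual datamodule style; a sketch, with method names assumed from other icefall TTS recipes:

    # Sketch of how the dataloaders are presumably built from here
    # (method names follow other icefall recipes; treat as assumptions):
    ljspeech = LJSpeechTtsDataModule(args)

    train_cuts = ljspeech.train_cuts()
    train_dl = ljspeech.train_dataloaders(train_cuts)

    valid_cuts = ljspeech.valid_cuts()
    valid_dl = ljspeech.valid_dataloaders(valid_cuts)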