using official repo

This commit is contained in:
yuekaiz 2024-12-22 18:48:26 +08:00
parent 604ab6f6b3
commit 511f63b551
2 changed files with 100 additions and 95 deletions

View File

@ -44,7 +44,7 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model.cfm import CFM
from model.dit import DiT
from model.utils import MelSpec
from model.utils import convert_char_to_pinyin
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
@ -151,7 +151,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
type=Path,
default="exp/valle_dev",
help="""The experiment dir.
It specifies the directory where all training related
@ -162,7 +162,7 @@ def get_parser():
parser.add_argument(
"--tokens",
type=str,
default="ft-tts/vocab.txt",
default="f5-tts/vocab.txt",
help="Path to the unique text tokens file",
)
@ -409,7 +409,7 @@ def get_model(params):
def load_pretrained_checkpoint(
model, ckpt_path, device: str = "cpu", dtype=torch.float32
):
model = model.to(dtype)
# model = model.to(dtype)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
checkpoint["model_state_dict"] = {
@ -548,14 +548,15 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
def prepare_input(batch: dict, tokenizer, device: torch.device):
def prepare_input(batch: dict, device: torch.device):
"""Parse batch data"""
print(batch.keys())
print(batch)
text_inputs = batch["text"]
mel_spec = batch["mel"].permute(0, 2, 1)
mel_lengths = batch["mel_lengths"]
return text_inputs, mel_spec, mel_lengths
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
print(text_inputs)
mel_spec = batch["features"]
mel_lengths = batch["features_lens"]
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
def compute_loss(
@ -584,34 +585,28 @@ def compute_loss(
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
(mel_spec, text_inputs, mel_lengths) = prepare_input(batch, device)
(text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
# at entry, TextTokens is (N, P)
assert text_inputs.ndim == 2
assert mel_spec.ndim == 3
with torch.set_grad_enabled(is_training):
loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
assert loss.requires_grad == is_training
print(loss)
# from accelerate import Accelerator
# from accelerate.utils import DistributedDataParallelKwargs
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# accelerator = Accelerator(
# kwargs_handlers=[ddp_kwargs],
# )
# accelerator.backward(loss)
# loss.backward()
info = MetricsTracker()
exit(0)
# with warnings.catch_warnings():
# warnings.simplefilter("ignore")
# info["frames"] = (audio_features_lens).sum().item()
# info["utterances"] = text_tokens.size(0)
# info["samples"] = mel_lengths.size(0)
# # # Note: We use reduction=sum while computing the loss.
# # info["loss"] = loss.detach().cpu().item()
# # for metric in metrics:
# # info[metric] = metrics[metric].detach().cpu().item()
# # del metrics
# # Note: We use reduction=sum while computing the loss.
# info["loss"] = loss.detach().cpu().item() * info["frames"]
# for i in range(len(loss_list)):
# info[f"loss_{i}"] = loss_list[i].detach().cpu().item() * info["frames"]
# for i in range(len(acc_list)):
# info[f"acc_{i}"] = acc_list[i] * info["frames"]
# info["loss"] = loss.detach().cpu().item() * info["samples"]
return loss, info
@ -734,6 +729,7 @@ def train_one_epoch(
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * (
1 / params.reset_interval
@ -753,7 +749,9 @@ def train_one_epoch(
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
for k in range(params.accumulate_grad_steps):
if isinstance(scheduler, Eden):
@ -926,11 +924,6 @@ def run(rank, world_size, args):
logging.info("Training started")
if args.tensorboard and rank == 0:
if params.train_stage:
tb_writer = SummaryWriter(
log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}"
)
else:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
@ -950,7 +943,7 @@ def run(rank, world_size, args):
logging.info("About to create model")
model = get_model(params)
model = load_pretrained_checkpoint(model, params.pretrained_model_path)
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
model = model.to(device)
@ -968,6 +961,7 @@ def run(rank, world_size, args):
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
@ -1029,7 +1023,7 @@ def run(rank, world_size, args):
dataset = TtsDataModule(args)
train_cuts = dataset.train_cuts()
valid_cuts = dataset.dev_cuts()
valid_cuts = dataset.valid_cuts()
train_cuts = filter_short_and_long_utterances(
train_cuts, params.filter_min_duration, params.filter_max_duration
@ -1041,7 +1035,7 @@ def run(rank, world_size, args):
train_dl = dataset.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_dl = dataset.dev_dataloaders(valid_cuts)
valid_dl = dataset.valid_dataloaders(valid_cuts)
if params.oom_check:
scan_pessimistic_batches_for_oom(
@ -1136,7 +1130,7 @@ def scan_pessimistic_batches_for_oom(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
print(23333)
dtype = torch.float32
if params.dtype in ["bfloat16", "bf16"]:
dtype = torch.bfloat16
@ -1145,16 +1139,17 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
print(batch.keys())
try:
with torch.cuda.amp.autocast(dtype=dtype):
_, loss, _ = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
tokenizer=tokenizer,
batch=batch,
is_training=True,
)
loss.backward()
loss.backward(retain_graph=True)
optimizer.zero_grad()
except Exception as e:
if "CUDA out of memory" in str(e):

View File

@ -24,21 +24,22 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
from fbank import MatchaFbank, MatchaFbankConfig
# from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from speech_synthesis import SpeechSynthesisDataset # noqa F401
from torch.utils.data import DataLoader
from icefall.utils import str2bool
@ -174,29 +175,32 @@ class TtsDataModule:
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# train = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
@ -242,26 +246,29 @@ class TtsDataModule:
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# validate = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@ -286,26 +293,29 @@ class TtsDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# test = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)