mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
using official repo
This commit is contained in:
parent
604ab6f6b3
commit
511f63b551
@ -44,7 +44,7 @@ from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model.cfm import CFM
|
||||
from model.dit import DiT
|
||||
from model.utils import MelSpec
|
||||
from model.utils import convert_char_to_pinyin
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
@ -151,7 +151,7 @@ def get_parser():
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
type=Path,
|
||||
default="exp/valle_dev",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
@ -162,7 +162,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="ft-tts/vocab.txt",
|
||||
default="f5-tts/vocab.txt",
|
||||
help="Path to the unique text tokens file",
|
||||
)
|
||||
|
||||
@ -409,7 +409,7 @@ def get_model(params):
|
||||
def load_pretrained_checkpoint(
|
||||
model, ckpt_path, device: str = "cpu", dtype=torch.float32
|
||||
):
|
||||
model = model.to(dtype)
|
||||
# model = model.to(dtype)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
|
||||
checkpoint["model_state_dict"] = {
|
||||
@ -548,14 +548,15 @@ def save_checkpoint(
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def prepare_input(batch: dict, tokenizer, device: torch.device):
|
||||
def prepare_input(batch: dict, device: torch.device):
|
||||
"""Parse batch data"""
|
||||
print(batch.keys())
|
||||
print(batch)
|
||||
text_inputs = batch["text"]
|
||||
mel_spec = batch["mel"].permute(0, 2, 1)
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
return text_inputs, mel_spec, mel_lengths
|
||||
# texts.extend(convert_char_to_pinyin([text], polyphone=true))
|
||||
text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
|
||||
print(text_inputs)
|
||||
mel_spec = batch["features"]
|
||||
mel_lengths = batch["features_lens"]
|
||||
return text_inputs, mel_spec.to(device), mel_lengths.to(device)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
@ -584,34 +585,28 @@ def compute_loss(
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
(mel_spec, text_inputs, mel_lengths) = prepare_input(batch, device)
|
||||
(text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
|
||||
# at entry, TextTokens is (N, P)
|
||||
assert text_inputs.ndim == 2
|
||||
assert mel_spec.ndim == 3
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
|
||||
assert loss.requires_grad == is_training
|
||||
print(loss)
|
||||
# from accelerate import Accelerator
|
||||
# from accelerate.utils import DistributedDataParallelKwargs
|
||||
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# accelerator = Accelerator(
|
||||
# kwargs_handlers=[ddp_kwargs],
|
||||
# )
|
||||
# accelerator.backward(loss)
|
||||
# loss.backward()
|
||||
|
||||
info = MetricsTracker()
|
||||
exit(0)
|
||||
# with warnings.catch_warnings():
|
||||
# warnings.simplefilter("ignore")
|
||||
# info["frames"] = (audio_features_lens).sum().item()
|
||||
# info["utterances"] = text_tokens.size(0)
|
||||
# info["samples"] = mel_lengths.size(0)
|
||||
|
||||
# # # Note: We use reduction=sum while computing the loss.
|
||||
# # info["loss"] = loss.detach().cpu().item()
|
||||
# # for metric in metrics:
|
||||
# # info[metric] = metrics[metric].detach().cpu().item()
|
||||
# # del metrics
|
||||
# # Note: We use reduction=sum while computing the loss.
|
||||
# info["loss"] = loss.detach().cpu().item() * info["frames"]
|
||||
|
||||
# for i in range(len(loss_list)):
|
||||
# info[f"loss_{i}"] = loss_list[i].detach().cpu().item() * info["frames"]
|
||||
# for i in range(len(acc_list)):
|
||||
# info[f"acc_{i}"] = acc_list[i] * info["frames"]
|
||||
# info["loss"] = loss.detach().cpu().item() * info["samples"]
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -734,6 +729,7 @@ def train_one_epoch(
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * (
|
||||
1 / params.reset_interval
|
||||
@ -753,7 +749,9 @@ def train_one_epoch(
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
# optimizer.zero_grad()
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
|
||||
for k in range(params.accumulate_grad_steps):
|
||||
if isinstance(scheduler, Eden):
|
||||
@ -926,12 +924,7 @@ def run(rank, world_size, args):
|
||||
logging.info("Training started")
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
if params.train_stage:
|
||||
tb_writer = SummaryWriter(
|
||||
log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}"
|
||||
)
|
||||
else:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
@ -950,7 +943,7 @@ def run(rank, world_size, args):
|
||||
logging.info("About to create model")
|
||||
|
||||
model = get_model(params)
|
||||
model = load_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||
|
||||
model = model.to(device)
|
||||
|
||||
@ -968,6 +961,7 @@ def run(rank, world_size, args):
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
@ -1029,7 +1023,7 @@ def run(rank, world_size, args):
|
||||
|
||||
dataset = TtsDataModule(args)
|
||||
train_cuts = dataset.train_cuts()
|
||||
valid_cuts = dataset.dev_cuts()
|
||||
valid_cuts = dataset.valid_cuts()
|
||||
|
||||
train_cuts = filter_short_and_long_utterances(
|
||||
train_cuts, params.filter_min_duration, params.filter_max_duration
|
||||
@ -1041,7 +1035,7 @@ def run(rank, world_size, args):
|
||||
train_dl = dataset.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
valid_dl = dataset.dev_dataloaders(valid_cuts)
|
||||
valid_dl = dataset.valid_dataloaders(valid_cuts)
|
||||
|
||||
if params.oom_check:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
@ -1136,7 +1130,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
|
||||
)
|
||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||
|
||||
print(23333)
|
||||
dtype = torch.float32
|
||||
if params.dtype in ["bfloat16", "bf16"]:
|
||||
dtype = torch.bfloat16
|
||||
@ -1145,16 +1139,17 @@ def scan_pessimistic_batches_for_oom(
|
||||
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
print(batch.keys())
|
||||
try:
|
||||
with torch.cuda.amp.autocast(dtype=dtype):
|
||||
_, loss, _ = compute_loss(
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
loss.backward()
|
||||
loss.backward(retain_graph=True)
|
||||
optimizer.zero_grad()
|
||||
except Exception as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
|
@ -24,21 +24,22 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from fbank import MatchaFbank, MatchaFbankConfig
|
||||
|
||||
# from fbank import MatchaFbank, MatchaFbankConfig
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpeechSynthesisDataset,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from speech_synthesis import SpeechSynthesisDataset # noqa F401
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
@ -174,29 +175,32 @@ class TtsDataModule:
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
raise NotImplementedError(
|
||||
"On-the-fly feature extraction is not implemented yet."
|
||||
)
|
||||
# sampling_rate = 22050
|
||||
# config = MatchaFbankConfig(
|
||||
# n_fft=1024,
|
||||
# n_mels=80,
|
||||
# sampling_rate=sampling_rate,
|
||||
# hop_length=256,
|
||||
# win_length=1024,
|
||||
# f_min=0,
|
||||
# f_max=8000,
|
||||
# )
|
||||
# train = SpeechSynthesisDataset(
|
||||
# return_text=True,
|
||||
# return_tokens=False,
|
||||
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
# return_cuts=self.args.return_cuts,
|
||||
# )
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
@ -242,26 +246,29 @@ class TtsDataModule:
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
raise NotImplementedError(
|
||||
"On-the-fly feature extraction is not implemented yet."
|
||||
)
|
||||
# sampling_rate = 22050
|
||||
# config = MatchaFbankConfig(
|
||||
# n_fft=1024,
|
||||
# n_mels=80,
|
||||
# sampling_rate=sampling_rate,
|
||||
# hop_length=256,
|
||||
# win_length=1024,
|
||||
# f_min=0,
|
||||
# f_max=8000,
|
||||
# )
|
||||
# validate = SpeechSynthesisDataset(
|
||||
# return_text=True,
|
||||
# return_tokens=False,
|
||||
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
# return_cuts=self.args.return_cuts,
|
||||
# )
|
||||
else:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -286,26 +293,29 @@ class TtsDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.info("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
raise NotImplementedError(
|
||||
"On-the-fly feature extraction is not implemented yet."
|
||||
)
|
||||
# sampling_rate = 22050
|
||||
# config = MatchaFbankConfig(
|
||||
# n_fft=1024,
|
||||
# n_mels=80,
|
||||
# sampling_rate=sampling_rate,
|
||||
# hop_length=256,
|
||||
# win_length=1024,
|
||||
# f_min=0,
|
||||
# f_max=8000,
|
||||
# )
|
||||
# test = SpeechSynthesisDataset(
|
||||
# return_text=True,
|
||||
# return_tokens=False,
|
||||
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
# return_cuts=self.args.return_cuts,
|
||||
# )
|
||||
else:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
return_text=True,
|
||||
return_tokens=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user