add tts training

root 2025-05-27 00:18:23 -07:00
parent 39700d5c94
commit 1281d7a515


@@ -50,6 +50,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from datasets import load_dataset
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
@@ -68,7 +69,7 @@ from transformers import (
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DistributedSampler, DataLoader
from train import add_model_arguments, add_training_arguments, get_params, compute_validation_loss, get_model, display_and_save_batch
from train import add_model_arguments, add_training_arguments, get_params, get_model
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
@@ -91,12 +92,12 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# parser.add_argument(
# "--loss-type",
# type=str,
# default="ce",
# help="The type of loss to use.",
# )
parser.add_argument(
"--batch-size",
type=int,
default=16,
help="The batch size to use.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
@@ -161,7 +162,7 @@ def preprocess(
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
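# Illustrative sketch, not part of this diff: the attention mask built in
# preprocess() above is simply "token is not a pad token" (a toy pad id of 0
# is assumed here).
import torch
_pad_id = 0                                        # hypothetical pad_token_id
_ids = torch.tensor([[5, 7, 9, _pad_id, _pad_id]])
_mask = _ids.ne(_pad_id)                           # [[True, True, True, False, False]]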
def data_collator(batch, tokenizer, cut_off_len=2048):
def data_collator(batch):
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
for i, item in enumerate(batch):
speech_tokens.append(item["code"])
@@ -176,21 +177,15 @@ def data_collator(batch, tokenizer, cut_off_len=2048):
lang.append(item["language"])
dnsmos.append(item["dnsmos"])
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
return {
"speech_tokens": speech_tokens,
"input_ids": input_ids,
"attention_mask": attention_mask,
"target_ids": target_ids,
"messages": messages,
"durations": durations,
"ids": ids,
"lang": lang,
"dnsmos": dnsmos,
}
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
@@ -216,7 +211,10 @@ def compute_loss(
Return a tuple of two elements. The first element is the loss tensor.
"""
device = next(model.parameters()).device
input_ids, attention_mask, target_ids, answer_cosyvoice_speech_token = batch["input_ids"], batch["attention_mask"], batch["target_ids"], batch["speech_tokens"]
messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
with torch.set_grad_enabled(is_training):
(
@@ -235,24 +233,51 @@ def compute_loss(
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
feature_lens = batch["supervisions"]["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = len(messages)
# Note: We use reduction=sum while computing the loss.
info["acc"] = acc * len(messages)
info["codec_acc"] = codec_acc * len(messages)
info["codec_topk_acc"] = codec_topk_acc * len(messages)
info["loss"] = loss.detach().cpu().item()
info["acc"] = (
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
info["codec_acc"] = codec_acc * info["frames"]
info["codec_topk_acc"] = codec_topk_acc * info["frames"]
info["codec_loss"] = codec_loss.detach().cpu().item()
info["text_loss"] = text_loss.detach().cpu().item()
return loss, info
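# Illustrative sketch, not part of this diff: why "acc" is scaled by
# len(messages) above. Assuming MetricsTracker sums each field over batches
# and divides by the summed "frames" when it reports averages, the scaling
# turns that ratio into an example-weighted mean accuracy:
_batches = [(0.50, 4), (0.90, 16)]                 # (mean acc, num examples) per batch
_acc_sum = sum(a * n for a, n in _batches)         # 0.5 * 4 + 0.9 * 16 = 16.4
_frames = sum(n for _, n in _batches)              # 20
assert abs(_acc_sum / _frames - 0.82) < 1e-6       # weighted mean, not (0.5 + 0.9) / 2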
def compute_validation_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
# FIX ME
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
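# Illustrative sketch, not part of this diff: one way the tot_loss.reduce(...)
# call above could aggregate metrics across ranks. The actual
# MetricsTracker.reduce in utils.py may differ; this only shows the
# all-reduce-of-summed-counters idea.
import torch
import torch.distributed as dist

def _reduce_metrics_sketch(metrics: dict, device: torch.device) -> dict:
    keys = sorted(metrics)
    buf = torch.tensor([float(metrics[k]) for k in keys], device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)     # sum every counter over all ranks
    return {k: v.item() for k, v in zip(keys, buf)}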
def train_one_epoch(
params: AttributeDict,
tokenizer: AutoTokenizer,
@@ -297,14 +322,14 @@ def train_one_epoch(
be set to 0.
"""
model.train()
model.encoder.eval()
# model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
batch_size = len(batch["durations"])
if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
@@ -315,7 +340,7 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
model.encoder.eval()
# model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
@@ -373,7 +398,6 @@ def train_one_epoch(
model.step()
except: # noqa
display_and_save_batch(batch, params=params)
raise
if batch_idx % params.log_interval == 0:
@@ -399,7 +423,7 @@ def train_one_epoch(
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_value = tot_loss["loss"] / tot_loss["frames"]
loss_value = tot_loss["loss"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
@@ -421,6 +445,7 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
params.valid_interval = 2000
fix_random_seed(params.seed)
@@ -428,9 +453,7 @@ def run(rank, world_size, args):
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
model, tokenizer = get_model(params)
if torch.cuda.is_available():
device = torch.device("cuda", get_local_rank())
else:
@@ -447,36 +470,34 @@ def run(rank, world_size, args):
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
ds = load_dataset(data_path, split="train")
train_test_split = dataset.train_test_split(test_size=1000, seed=42)
# print(params.dataset)
ds = load_dataset(params.dataset, split="train")
# shuffle the dataset
ds = ds.shuffle(seed=42)
train_test_split = ds.train_test_split(test_size=1000, seed=42)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
# train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_dl = StatefulDataLoader(
train_dataset,
batch_size=2,
batch_size=params.batch_size,
sampler=sampler,
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=lambda features: data_collator(
features, tokenizer
),
num_workers=4,
prefetch_factor=2,
collate_fn=data_collator
)
train_dl.load_state_dict(sampler_state_dict)
valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
valid_dl = DataLoader(
eval_dataset,
batch_size=2,
batch_size=params.batch_size,
sampler=valid_sampler,
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=lambda features: data_collator(
features
),
collate_fn=data_collator
)
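# Illustrative sketch, not part of this diff: the mid-epoch resume that
# params.sampler_state_dict_path relies on uses StatefulDataLoader's
# state_dict()/load_state_dict() pair. The file name below is hypothetical.
_dl_state = train_dl.state_dict()                  # snapshot of the loader position
torch.save(_dl_state, f"{params.exp_dir}/dataloader-state.pt")
# A new StatefulDataLoader built over the same dataset and sampler can resume:
train_dl.load_state_dict(torch.load(f"{params.exp_dir}/dataloader-state.pt"))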
if args.tensorboard and rank == 0:
@@ -533,7 +554,6 @@ def run(rank, world_size, args):
logging.info("Done!")
def main():
parser = get_parser()
args = parser.parse_args()