From 1281d7a515b0bfbdd7e7e4e926f4d0c644af1448 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 27 May 2025 00:18:23 -0700
Subject: [PATCH] add tts training

---
 .../SPEECH2SPEECH/qwen_omni/train_tts.py | 120 ++++++++++--------
 1 file changed, 70 insertions(+), 50 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
index 8fd6609a4..38132e71e 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
@@ -50,6 +50,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
+from datasets import load_dataset
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 
@@ -68,7 +69,7 @@ from transformers import (
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import DistributedSampler, DataLoader
 
-from train import add_model_arguments, add_training_arguments, get_params, compute_validation_loss, get_model, display_and_save_batch
+from train import add_model_arguments, add_training_arguments, get_params, get_model
 from utils import (  # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -91,12 +92,12 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    # parser.add_argument(
-    #     "--loss-type",
-    #     type=str,
-    #     default="ce",
-    #     help="The type of loss to use.",
-    # )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=16,
+        help="The batch size to use.",
+    )
 
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
@@ -161,7 +162,7 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids
 
-def data_collator(batch, tokenizer, cut_off_len=2048):
+def data_collator(batch):
     speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
     for i, item in enumerate(batch):
         speech_tokens.append(item["code"])
@@ -176,21 +177,15 @@ def data_collator(batch, tokenizer, cut_off_len=2048):
         lang.append(item["language"])
         dnsmos.append(item["dnsmos"])
 
-    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
-    target_ids = target_ids.type(torch.LongTensor)
-    input_ids = input_ids.type(torch.LongTensor)
-
     return {
         "speech_tokens": speech_tokens,
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "target_ids": target_ids,
+        "messages": messages,
         "durations": durations,
         "ids": ids,
         "lang": lang,
         "dnsmos": dnsmos,
     }
-
+
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
@@ -216,7 +211,10 @@ def compute_loss(
       Return a tuple of two elements. The first element is the loss tensor.
     """
     device = next(model.parameters()).device
-    input_ids, attention_mask, target_ids, answer_cosyvoice_speech_token = batch["input_ids"], batch["attention_mask"], batch["target_ids"], batch["speech_tokens"]
+    messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
+    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
+    target_ids = target_ids.type(torch.LongTensor)
+    input_ids = input_ids.type(torch.LongTensor)
 
     with torch.set_grad_enabled(is_training):
         (
@@ -235,24 +233,51 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        feature_lens = batch["supervisions"]["num_frames"]
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
-
+    info["frames"] = len(messages)
     # Note: We use reduction=sum while computing the loss.
+    info["acc"] = acc * len(messages)
+    info["codec_acc"] = codec_acc * len(messages)
+    info["codec_topk_acc"] = codec_topk_acc * len(messages)
     info["loss"] = loss.detach().cpu().item()
-    info["acc"] = (
-        acc * info["frames"]
-    )  # WAR: to avoid normalization by the number of frames
-
-    info["codec_acc"] = codec_acc * info["frames"]
-    info["codec_topk_acc"] = codec_topk_acc * info["frames"]
     info["codec_loss"] = codec_loss.detach().cpu().item()
     info["text_loss"] = text_loss.detach().cpu().item()
 
     return loss, info
 
+def compute_validation_loss(
+    params: AttributeDict,
+    tokenizer: AutoTokenizer,
+    model: nn.Module,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                batch=batch,
+                is_training=False,
+            )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    # FIX ME
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
 def train_one_epoch(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
@@ -297,14 +322,14 @@ def train_one_epoch(
         be set to 0.
     """
     model.train()
-    model.encoder.eval()
+    # model.encoder.eval()
     if not params.unfreeze_llm:
         model.llm.eval()
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
-        batch_size = len(batch["supervisions"]["text"])
+        batch_size = len(batch["durations"])
         if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -315,7 +340,7 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            model.encoder.eval()
+            # model.encoder.eval()
            if not params.unfreeze_llm:
                 model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
@@ -373,7 +398,6 @@ def train_one_epoch(
 
             model.step()
         except:  # noqa
-            display_and_save_batch(batch, params=params)
             raise
 
         if batch_idx % params.log_interval == 0:
@@ -399,7 +423,7 @@ def train_one_epoch(
                 )
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    loss_value = tot_loss["loss"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -421,6 +445,7 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    params.valid_interval = 2000
 
     fix_random_seed(params.seed)
 
@@ -428,9 +453,7 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info(params)
     logging.info("About to create model")
-
     model, tokenizer = get_model(params)
-
     if torch.cuda.is_available():
         device = torch.device("cuda", get_local_rank())
     else:
@@ -447,36 +470,34 @@ def run(rank, world_size, args):
     sampler_state_dict = None
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-
-    data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
-    ds = load_dataset(data_path, split="train")
-    train_test_split = dataset.train_test_split(test_size=1000, seed=42)
+    # print(params.dataset)
+    ds = load_dataset(params.dataset, split="train")
+    # shuffle the dataset
+    ds = ds.shuffle(seed=42)
+    train_test_split = ds.train_test_split(test_size=1000, seed=42)
     train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
+    # train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
     sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
     train_dl = StatefulDataLoader(
         train_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=sampler,
         shuffle=False,
-        num_workers=1,
-        prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features, tokenizer
-        ),
+        num_workers=4,
+        prefetch_factor=2,
+        collate_fn=data_collator
     )
     train_dl.load_state_dict(sampler_state_dict)
 
     valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
     valid_dl = DataLoader(
         eval_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=valid_sampler,
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features
-        ),
+        collate_fn=data_collator
     )
 
     if args.tensorboard and rank == 0:
@@ -533,7 +554,6 @@ def run(rank, world_size, args):
 
     logging.info("Done!")
 
-
 def main():
     parser = get_parser()
     args = parser.parse_args()