diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/config.json b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json
new file mode 100644
index 000000000..06aeb51f1
--- /dev/null
+++ b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json
@@ -0,0 +1,27 @@
+{
+    "llm_model_name_or_path": "/workspace/slam/icefall_omni/egs/speech_llm/SPEECH2SPEECH/models/Qwen2.5-0.5B-Instruct",
+    "data_path": ["../emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
+    "bf16": false,
+    "output_dir": "./exp_zh",
+    "num_train_epochs": 3,
+    "per_device_train_batch_size": 8,
+    "per_device_eval_batch_size": 8,
+    "gradient_accumulation_steps": 1,
+    "evaluation_strategy": "steps",
+    "eval_steps": 1000,
+    "save_strategy": "steps",
+    "save_steps": 5000,
+    "save_total_limit": 100,
+    "learning_rate": 0.00005,
+    "weight_decay": 0.01,
+    "adam_beta2": 0.95,
+    "warmup_ratio": 0.03,
+    "lr_scheduler_type": "cosine",
+    "logging_steps": 100,
+    "report_to": "wandb",
+    "model_max_length": 2048,
+    "gradient_checkpointing": false,
+    "dataloader_num_workers": 4,
+    "dataloader_prefetch_factor": 4,
+    "deepspeed": "ds_config_zero2.json"
+}
diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json b/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json
new file mode 100644
index 000000000..b0b139598
--- /dev/null
+++ b/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json
@@ -0,0 +1,47 @@
+{
+    "fp16": {
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 64,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+    "bf16": {
+        "enabled": "auto"
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "weight_decay": "auto"
+        }
+    },
+
+    "scheduler": {
+        "type": "WarmupCosineLR",
+        "params": {
+            "total_num_steps": "auto",
+            "warmup_min_ratio": 0.03,
+            "warmup_num_steps": "auto",
+            "cos_min_ratio": 0.1
+        }
+    },
+
+    "zero_optimization": {
+        "stage": 2,
+        "overlap_comm": false,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto"
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": 1.0,
+    "steps_per_print": 100,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt
new file mode 100644
index 000000000..09e069d3a
--- /dev/null
+++ b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt
@@ -0,0 +1,7 @@
+torch
+transformers
+wandb
+datasets
+accelerate>=0.26.0
+deepspeed
+flash-attn
diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh b/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh
new file mode 100644
index 000000000..a78bba96b
--- /dev/null
+++ b/egs/emilia/TTS/llasa_cosyvoice2_token/run.sh
@@ -0,0 +1,4 @@
+
+WANDB_KEY=your_wandb_api_key_here # TODO Set YOUR KEY!
+wandb login ${WANDB_KEY}
+torchrun --nproc_per_node=8 train.py config.json
diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/train.py b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py
new file mode 100644
index 000000000..159e483d7
--- /dev/null
+++ b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py
@@ -0,0 +1,171 @@
+import json
+import os
+import random
+import sys
+from dataclasses import dataclass, field
+from functools import partial
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import transformers
+import wandb
+from datasets import load_dataset, load_from_disk
+from torch.utils.data import DataLoader, Dataset
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorForSeq2Seq,
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+)
+from transformers.trainer_pt_utils import LabelSmoother
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+
+
+@dataclass
+class ModelArguments:
+    llm_model_name_or_path: Optional[str] = field(
+        default="meta-llama/Llama-3.2-1B-Instruct"
+    )
+
+
+@dataclass
+class DataArguments:
+    data_path: List[str] = field(
+        default=None,
+        metadata={"help": "Root path(s) to the data. Can be single path or list."},
+    )
+
+
+@dataclass
+class CustomTrainingArguments(TrainingArguments):
+    optim: str = field(default="adamw_torch_fused")
+    model_max_length: int = field(
+        default=2048,
+        metadata={"help": "Maximum sequence length"},
+    )
+    logging_steps: int = field(default=100, metadata={"help": "Log every X updates"})
+    report_to: Optional[str] = field(
+        default=None,
+        metadata={"help": "The integration to report the results and logs to."},
+    )
+    run_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the run for logging."}
+    )
+    gradient_checkpointing: bool = field(default=False)
+    lr_scheduler_type: str = field(
+        default="cosine", metadata={"help": "The learning rate scheduler to use."}
+    )
+    remove_unused_columns: bool = field(default=False)
+
+
+def data_collator(batch, tokenizer):
+    speech_generation_start_index = tokenizer.convert_tokens_to_ids(
+        "<|SPEECH_GENERATION_START|>"
+    )
+    assistant_index = tokenizer.convert_tokens_to_ids("assistant")
+    input_ids_list = []
+    for i, item in enumerate(batch):
+        text, code = item["text"], item["code"]
+        message = [
+            {"role": "user", "content": f"Convert the text to speech: {text}"},
+            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
+        ]
+
+        input_ids = tokenizer.apply_chat_template(
+            message,
+            tokenize=True,
+            chat_template=TEMPLATE,
+        )
+
+        code = [c + 151665 for c in code]
+
+        idx = input_ids.index(speech_generation_start_index)
+        input_ids = input_ids[:idx] + code + input_ids[idx + 1 :]
+        if len(input_ids) < 2048:
+            input_ids_list.append(input_ids)
+
+    max_len = max([len(input_ids) for input_ids in input_ids_list])
+    input_ids_list = [
+        input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
+        for input_ids in input_ids_list
+    ]
+    input_ids = torch.tensor(input_ids_list, dtype=torch.int)
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+
+    target_ids = input_ids.clone()
+    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    mask_indices = torch.where(input_ids == assistant_index)
+    for i in range(mask_indices[0].size(0)):
+        row = mask_indices[0][i]
+        col = mask_indices[1][i]
+        # + 2 to skip: 'assistant', '\n'
+        target_ids[row, : col + 2] = IGNORE_TOKEN_ID
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "labels": target_ids.to(dtype=torch.int64),
+    }
+
+
+def main():
+    parser = transformers.HfArgumentParser(
+        (ModelArguments, DataArguments, CustomTrainingArguments)
+    )
+    assert len(sys.argv) == 2 and sys.argv[1].endswith(".json")
+    (
+        model_args,
+        data_args,
+        training_args,
+    ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+
+    is_main_process = training_args.local_rank in [-1, 0]
+    if training_args.report_to == "wandb" and is_main_process:
+        wandb.init(
+            project="llm_tts",
+            config=training_args.to_sanitized_dict(),
+            name=training_args.run_name,
+        )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.llm_model_name_or_path,
+        torch_dtype=torch.float16,
+        attn_implementation="flash_attention_2",
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
+    new_tokens = [f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"]
+    num_added_tokens = tokenizer.add_tokens(new_tokens)
+
+    model.resize_token_embeddings(len(tokenizer))
+    model.vocab_size = len(tokenizer)
+
+    dataset = load_dataset("json", data_files=data_args.data_path)
+    dataset = dataset["train"]
+    train_test_split = dataset.train_test_split(test_size=100, seed=42)
+    train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
+
+    trainer = Trainer(
+        model=model,
+        tokenizer=tokenizer,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        data_collator=lambda features: data_collator(features, tokenizer),
+    )
+
+    if is_main_process:
+        trainer.add_callback(transformers.integrations.WandbCallback())
+
+    trainer.train(resume_from_checkpoint=None)
+    trainer.save_model(training_args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
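
Note on the data format and token mapping: each line of cosy_v2_tokens_ZH.jsonl is expected to provide a "text" string and a "code" list of CosyVoice2 codec ids, and data_collator in train.py shifts every code by 151665 (the size of the stock Qwen2.5 tokenizer) so that code i lands on the newly added <|s_i|> token. The snippet below is an illustrative sanity check, not part of the recipe: it assumes train.py is importable, uses the public Qwen/Qwen2.5-0.5B-Instruct checkpoint in place of the local path from config.json, and feeds a toy batch with made-up codes to confirm that the collator masks the prompt with IGNORE_TOKEN_ID.

# sanity_check_collator.py -- illustrative only; the Hub id below is assumed to
# match the local Qwen2.5-0.5B-Instruct checkpoint referenced in config.json.
from transformers import AutoTokenizer

from train import IGNORE_TOKEN_ID, data_collator

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Extend the vocabulary exactly as main() does: 6561 speech tokens + 1 marker token.
tokenizer.add_tokens([f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"])

# Toy batch with made-up codec ids, mirroring the JSONL schema {"text": ..., "code": [...]}.
batch = [
    {"text": "hello world", "code": [0, 1, 2, 3]},
    {"text": "a slightly longer sentence", "code": [5, 6, 7]},
]

out = data_collator(batch, tokenizer)
print(out["input_ids"].shape, out["labels"].shape)
# Everything up to and including the "assistant\n" prefix should be masked in the labels.
print((out["labels"] == IGNORE_TOKEN_ID).sum(dim=1))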