add training codes

This commit is contained in:
root 2025-02-28 02:08:05 +00:00
parent 540430d213
commit fa6587010e
5 changed files with 256 additions and 0 deletions

View File

@ -0,0 +1,27 @@
{
"llm_model_name_or_path": "/workspace/slam/icefall_omni/egs/speech_llm/SPEECH2SPEECH/models/Qwen2.5-0.5B-Instruct",
"data_path": ["../emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
"bf16": false,
"output_dir": "./exp_zh",
"num_train_epochs": 3,
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"gradient_accumulation_steps": 1,
"evaluation_strategy": "steps",
"eval_steps": 1000,
"save_strategy": "steps",
"save_steps": 5000,
"save_total_limit": 100,
"learning_rate": 0.00005,
"weight_decay": 0.01,
"adam_beta2": 0.95,
"warmup_ratio": 0.03,
"lr_scheduler_type": "cosine",
"logging_steps": 100,
"report_to": "wandb",
"model_max_length": 2048,
"gradient_checkpointing": false,
"dataloader_num_workers": 4,
"dataloader_prefetch_factor": 4,
"deepspeed": "ds_config_zero2.json"
}

View File

@ -0,0 +1,47 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 64,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupCosineLR",
"params": {
"total_num_steps": "auto",
"warmup_min_ratio": 0.03,
"warmup_num_steps": "auto",
"cos_min_ratio": 0.1
}
},
"zero_optimization": {
"stage": 2,
"overlap_comm": false,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@ -0,0 +1,7 @@
torch
transformers
wandb
datasets
accelerate>=0.26.0
deepspeed
flash-attn

View File

@ -0,0 +1,4 @@
WANDB_KEY=df59308c1f07be8338a87497523163014442d605 # TODO Set YOUR KEY!
wandb login ${WANDB_KEY}
torchrun --nproc_per_node=8 train.py config.json

View File

@ -0,0 +1,171 @@
import json
import os
import random
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import transformers
import wandb
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader, Dataset
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Trainer,
TrainingArguments,
)
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
@dataclass
class ModelArguments:
llm_model_name_or_path: Optional[str] = field(
default="meta-llama/Llama-3.2-1B-Instruct"
)
@dataclass
class DataArguments:
data_path: List[str] = field(
default=None,
metadata={"help": "Root path(s) to the data. Can be single path or list."},
)
@dataclass
class CustomTrainingArguments(TrainingArguments):
optim: str = field(default="adamw_torch_fused")
model_max_length: int = field(
default=2048,
metadata={"help": "Maximum sequence length"},
)
logging_steps: int = field(default=100, metadata={"help": "Log every X updates"})
report_to: Optional[str] = field(
default=None,
metadata={"help": "The integration to report the results and logs to."},
)
run_name: Optional[str] = field(
default=None, metadata={"help": "The name of the run for logging."}
)
gradient_checkpointing: bool = field(default=False)
lr_scheduler_type: str = field(
default="cosine", metadata={"help": "The learning rate scheduler to use."}
)
remove_unused_columns: bool = field(default=False)
def data_collator(batch, tokenizer):
speech_generation_start_index = tokenizer.convert_tokens_to_ids(
"<|SPEECH_GENERATION_START|>"
)
assistant_index = tokenizer.convert_tokens_to_ids("assistant")
input_ids_list = []
for i, item in enumerate(batch):
text, code = item["text"], item["code"]
message = [
{"role": "user", "content": f"Convert the text to speech: {text}"},
{"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
]
input_ids = tokenizer.apply_chat_template(
message,
tokenize=True,
chat_template=TEMPLATE,
)
code = [c + 151665 for c in code]
idx = input_ids.index(speech_generation_start_index)
input_ids = input_ids[:idx] + code + input_ids[idx + 1 :]
if len(input_ids) < 2048:
input_ids_list.append(input_ids)
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [
input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
for input_ids in input_ids_list
]
input_ids = torch.tensor(input_ids_list, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
mask_indices = torch.where(input_ids == assistant_index)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": target_ids.to(dtype=torch.int64),
}
def main():
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, CustomTrainingArguments)
)
assert len(sys.argv) == 2 and sys.argv[1].endswith(".json")
(
model_args,
data_args,
training_args,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
is_main_process = training_args.local_rank in [-1, 0]
if training_args.report_to == "wandb" and is_main_process:
wandb.init(
project="llm_tts",
config=training_args.to_sanitized_dict(),
name=training_args.run_name,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.llm_model_name_or_path,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
new_tokens = [f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"]
num_added_tokens = tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))
model.vocab_size = len(tokenizer)
dataset = load_dataset("json", data_files=data_args.data_path)
dataset = dataset["train"]
train_test_split = dataset.train_test_split(test_size=100, seed=42)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=lambda features: data_collator(features, tokenizer),
)
if is_main_process:
trainer.add_callback(transformers.integrations.WandbCallback())
trainer.train(resume_from_checkpoint=None)
trainer.save_model(training_args.output_dir)
if __name__ == "__main__":
main()