add loss type (ce / kl_div): KL distillation from a frozen text LLM teacher

root 2025-05-19 01:31:21 +00:00
parent e52581e69b
commit 4a29430349
4 changed files with 367 additions and 83 deletions

View File

@@ -239,7 +239,8 @@ fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "stage 14: Client"
datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
datasets=(openbookqa commoneval)
for dataset in ${datasets[@]}; do
# sd-qa should use usa split
if [ $dataset == "sd-qa" ]; then
@@ -250,17 +251,16 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
echo $dataset $split_name
python3 ./qwen_omni/client.py \
--subset-name $dataset --split-name $split_name \
--output-dir test_result
--output-dir result_adapter_librispeech_kl_div_qa_template
done
fi
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
log "stage 15: Training Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text
ngpu=2
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 600 \
--max-duration 700 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
@@ -271,7 +271,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
--start-epoch 4 --pretrained-model-path $exp_dir/epoch-3/pytorch_model.bin \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
fi
@@ -321,3 +321,67 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
$train_cmd_args
fi
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
log "stage 17: Server for adapter only speech continuation"
exp_dir=./qwen_omni/exp_speech2text
python3 ./qwen_omni/server.py \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output False \
--use-lora False --prompt-template continuation
fi
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
log "stage 18: Training kl-div Speech2Speech Model, adaptor only"
exp_dir=./qwen_omni/exp_speech2text_kl
ngpu=2
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 700 \
--enable-musan False \
--audio-key audio --text-key continuation \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--on-the-fly-feats True \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--dataset-format speech_continuation \
--loss-type kl_div --dataset librispeech \
--pretrained-model-path $exp_dir/checkpoint-1001/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-1001/sampler.pt \
--use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
fi
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
log "stage 19: Server for kl loss"
exp_dir=./qwen_omni/exp_speech2text_kl
python3 ./qwen_omni/server.py \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output False \
--use-lora False --prompt-template qa
fi
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "stage 20: Training Speech2Speech Model, adaptor + lora, second stage"
exp_dir=./qwen_omni/exp_speech2text_kl_llm
pretrained_dir=./qwen_omni/exp_speech2text_kl
ngpu=2
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 200 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--pretrained-model-path $pretrained_dir/epoch-10/pytorch_model.bin \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output False --dataset-format vocalnet
fi

View File

@@ -64,6 +64,8 @@ class SPEECH_LLM(nn.Module):
encoder_projector: nn.Module,
codec_lm: nn.Module = None,
codec_lm_padding_side: str = "left",
teacher_llm: nn.Module = None,
kl_temperature: float = 2.0,
):
super().__init__()
self.encoder = encoder
@@ -92,6 +94,9 @@ class SPEECH_LLM(nn.Module):
multidim_average="global",
ignore_index=IGNORE_TOKEN_ID,
)
if teacher_llm is not None:
self.teacher_llm = teacher_llm
self.kl_temperature = kl_temperature
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
@@ -256,6 +261,67 @@ class SPEECH_LLM(nn.Module):
)
return model_outputs.loss, acc
def forward_kl_div(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
teacher_input_ids: torch.LongTensor = None,
teacher_attention_mask: torch.Tensor = None,
teacher_labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
)
teacher_outputs = self.teacher_llm(
input_ids=teacher_input_ids,
attention_mask=teacher_attention_mask,
)
kl_loss = torch.nn.functional.kl_div(
torch.nn.functional.log_softmax(
model_outputs.logits[labels != -100] / self.kl_temperature,
dim=-1,
),
torch.nn.functional.softmax(
teacher_outputs.logits[teacher_labels != -100] / self.kl_temperature,
dim=-1,
),
reduction="batchmean",
)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
teacher_preds = torch.argmax(teacher_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
acc_teacher = compute_accuracy(
teacher_preds.detach()[:, :-1],
teacher_labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return kl_loss, acc, acc_teacher
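For reference, the objective above is a temperature-scaled KL divergence between the student's and the frozen teacher's next-token distributions, evaluated only at answer positions (labels != -100). A minimal standalone sketch of that computation, with toy shapes and names chosen for illustration (not part of this commit):

import torch
import torch.nn.functional as F

def distillation_kl(student_logits, teacher_logits, temperature=2.0):
    # Both tensors: (num_answer_tokens, vocab_size), already gathered at
    # positions where the labels are not IGNORE_TOKEN_ID.
    return F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    )

print(distillation_kl(torch.randn(8, 16), torch.randn(8, 16)))  # toy vocab of 16

Note that reduction="batchmean" averages over the gathered token positions, and that the usual T^2 gradient-scaling factor from distillation papers is not applied here.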
def forward_with_speech_output(
self,
fbank: torch.Tensor = None,

View File

@@ -21,6 +21,12 @@ def get_args():
default=None,
help="Checkpoint name or path, default to %(default)r",
)
parser.add_argument(
"--prompt-template",
type=str,
default=None,
help="Prompt template",
)
add_model_arguments(parser)
args = parser.parse_args()
return args
@@ -59,8 +65,23 @@ model, tokenizer = get_model(args)
app = FastAPI()
device = torch.device("cuda")
if args.prompt_template is None:
template = f"{DEFAULT_SPEECH_TOKEN}"
elif args.prompt_template == "qa":
template = f"Answer the following question:\n\n{DEFAULT_SPEECH_TOKEN}"
elif args.prompt_template == "continuation":
template = f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}"
elif args.prompt_template == "asr":
template = (
f"Repeat the following text, without any explanation: {DEFAULT_SPEECH_TOKEN}"
)
elif args.prompt_template == "mt":
template = f"Please translate the text to Chinese. Your response should only include the Chinese translation, without any additional words:\n\n{DEFAULT_SPEECH_TOKEN}"
else:
raise ValueError(f"Invalid prompt template: {args.prompt_template}")
print("Using template:", template)
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "user", "content": template},
{"role": "assistant", "content": ""},
]
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
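For reference, this Jinja template closes every turn except the last one, so the rendered prompt ends with an open assistant header that the model completes. A small sketch of how it renders (assumes jinja2 is installed; '<speech>' stands in for DEFAULT_SPEECH_TOKEN):

from jinja2 import Template

TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
messages = [
    {"role": "user", "content": "Answer the following question:\n\n<speech>"},
    {"role": "assistant", "content": ""},
]
print(Template(TEMPLATE).render(messages=messages))
# <|im_start|>user
# Answer the following question:
#
# <speech><|im_end|>
# <|im_start|>assistant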

View File

@@ -74,8 +74,8 @@ from transformers import (
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
get_rank,
get_local_rank,
get_rank,
get_world_size,
setup_logger,
str2bool,
@@ -234,6 +234,21 @@ def get_parser():
default="slam_omni",
help="The format of the dataset.",
)
parser.add_argument(
"--dataset",
type=str,
default="multi_en",
help="The name of the dataset.",
)
parser.add_argument(
"--loss-type",
type=str,
default="ce",
help="The type of loss to use.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
@@ -335,6 +350,22 @@ def process_batch_vocalnet(batch: dict):
return messages, answer_cosyvoice_speech_token
def process_batch_text_vocalnet(batch: dict):
pass
answers = batch["supervisions"]["text"]
answer_cosyvoice_speech_token = [
cut.custom["speech_token"] for cut in batch["supervisions"]["cut"]
]
messages = []
for i in range(len(answers)):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]},
]
messages.append(message)
return messages, answer_cosyvoice_speech_token
def process_batch_speech_continuation(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
@@ -350,6 +381,131 @@ def process_batch_speech_continuation(batch: dict):
return messages
def process_batch_text_continuation(batch: dict):
messages = []
for i in range(len(batch["supervisions"]["text"])):
transcript = batch["supervisions"]["cut"][i].custom["text"]
message = [
{
"role": "user",
"content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": batch["supervisions"]["text"][i]},
]
messages.append(message)
return messages
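Taken together with process_batch_speech_continuation, each utterance yields a student prompt carrying the speech placeholder and a teacher prompt carrying the ground-truth transcript, so the kl_div loss pulls the speech-conditioned predictions toward the text-conditioned teacher on the same continuation. A hedged sketch of one such pair (the strings are made up, and the student prompt is assumed to mirror the 'continuation' template used in server.py; '<speech>' stands in for DEFAULT_SPEECH_TOKEN):

transcript = "The quick brown fox"  # hypothetical ground-truth transcript
continuation = "jumps over the lazy dog and trots off into the woods."  # hypothetical target text

student_message = [  # assumed form of what process_batch_speech_continuation builds
    {"role": "user", "content": "Continue the following text using less than 50 words:\n\n<speech>"},
    {"role": "assistant", "content": continuation},
]
teacher_message = [  # built by process_batch_text_continuation above
    {"role": "user", "content": f"Continue the following text using less than 50 words:\n\n{transcript}<speech>"},
    {"role": "assistant", "content": continuation},
]

preprocess_teacher (below) later strips the speech placeholder from the teacher sequence, so the supervised answer positions line up with the student's once the student's placeholder is replaced by projected speech features.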
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens up to and including the '<|im_start|>assistant\n' header with IGNORE_TOKEN_ID
# first get the indices of DEFAULT_SPEECH_TOKEN
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
# WAR: TODO FIXME check qwen3
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
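The col + 6 offset masks everything up to and including the '<|im_start|>assistant\n' header, so only the assistant reply supervises the loss. A toy illustration (the speech token id and the five ids after it follow the comment in the code; the other ids, including 9906 as a reply token, are illustrative):

import torch

IGNORE_TOKEN_ID = -100
SPEECH = 151665  # assumed id of DEFAULT_SPEECH_TOKEN
# <|im_start|> user '\n' <speech> <|im_end|> '\n' <|im_start|> assistant '\n' reply <|im_end|>
input_ids = torch.tensor([[151644, 872, 198, SPEECH, 151645, 198, 151644, 77091, 198, 9906, 151645]])
target_ids = input_ids.clone()
rows, cols = torch.where(input_ids == SPEECH)
r, c = rows[0].item(), cols[0].item()
target_ids[r, : c + 6] = IGNORE_TOKEN_ID
print(target_ids)  # only the reply tokens (9906, 151645) remain as training targets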
def preprocess_teacher(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 6 to skip: <speech>, <|im_end|>, '\n', <|im_start|>, 'assistant', '\n'
# WAR: TODO FIXME check qwen3
# Difference from preprocess: the speech token id is re-inserted into target_ids
# below so that stripping it afterwards removes exactly one position per row
# from both input_ids and target_ids.
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
target_ids[row, col] = default_speech_token_id
# remove default_speech_token_id from target_ids and input_ids
batch_size = target_ids.size(0)
target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
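The stripping step at the end only works because every teacher row contains exactly one speech token: removing it from both tensors drops one position per row, so view(batch_size, -1) restores a rectangular batch, and the number of supervised (non-ignored) positions matches the student's, which forward_kl_div relies on when it gathers logits where labels != -100. A toy illustration:

import torch

SPEECH = 151665  # assumed id of DEFAULT_SPEECH_TOKEN
ids = torch.tensor([[11, SPEECH, 22, 33],
                    [44, SPEECH, 55, 66]])
stripped = ids[ids != SPEECH].view(ids.size(0), -1)
print(stripped)
# tensor([[11, 22, 33],
#         [44, 55, 66]])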
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
@@ -374,72 +530,6 @@ def compute_loss(
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
# For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
# WAR: TODO FIXME check qwen3
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
# max_frames = params.max_duration * 1000 // params.frame_shift_ms
# allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
# batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = next(model.parameters()).device
feature = batch["inputs"]
@@ -452,8 +542,12 @@ def compute_loss(
messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
elif params.dataset_format == "vocalnet":
messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
if params.loss_type == "kl_div":
messages_text = process_batch_text_vocalnet(batch)
elif params.dataset_format == "speech_continuation":
messages = process_batch_speech_continuation(batch)
if params.loss_type == "kl_div":
messages_text = process_batch_text_continuation(batch)
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
@@ -464,12 +558,30 @@ def compute_loss(
with torch.set_grad_enabled(is_training):
if not params.enable_speech_output:
loss, acc = model(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
)
if params.loss_type == "ce":
loss, acc = model(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
)
elif params.loss_type == "kl_div":
(
teacher_input_ids,
teacher_attention_mask,
teacher_target_ids,
) = preprocess_teacher(messages_text, tokenizer)
loss, acc, acc_teacher = model.forward_kl_div(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
teacher_input_ids=teacher_input_ids.to(device),
teacher_attention_mask=teacher_attention_mask.to(device),
teacher_labels=teacher_target_ids.to(device),
)
else:
raise ValueError(f"Unknown loss type: {params.loss_type}")
else:
(
text_loss,
@@ -498,6 +610,8 @@ def compute_loss(
info["acc"] = (
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
if params.loss_type == "kl_div":
info["acc_teacher"] = acc_teacher * info["frames"]
if params.enable_speech_output:
info["codec_acc"] = codec_acc * info["frames"]
info["codec_topk_acc"] = codec_topk_acc * info["frames"]
@@ -820,6 +934,17 @@ def run(rank, world_size, args):
codec_lm.config.mask_token_id = codec_vocab_size - 4
else:
codec_lm = None
if params.loss_type == "kl_div":
teacher_llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
for name, param in teacher_llm.named_parameters():
param.requires_grad = False
teacher_llm.eval()
else:
teacher_llm = None
model = SPEECH_LLM(
speech_encoder,
@@ -827,6 +952,7 @@ def run(rank, world_size, args):
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
teacher_llm=teacher_llm,
)
if params.pretrained_model_path:
@@ -834,7 +960,9 @@ def run(rank, world_size, args):
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
# set params.batch_idx_train according to the checkpoint name
if "checkpoint-" in params.pretrained_model_path:
params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1])
params.batch_idx_train = int(
params.pretrained_model_path.split("-")[-1].split("/")[0]
)
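The extra split("/") guards against checkpoint paths where the step number is followed by a file name, as in the stage 18 command above. A quick illustration with that path:

path = "./qwen_omni/exp_speech2text_kl/checkpoint-1001/pytorch_model.bin"
# Old parsing: int(path.split("-")[-1]) tries int("1001/pytorch_model.bin") and raises ValueError.
batch_idx_train = int(path.split("-")[-1].split("/")[0])
print(batch_idx_train)  # 1001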
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -893,9 +1021,14 @@ def run(rank, world_size, args):
train_cuts = data_module.train_cuts_en_vocalnet()
valid_cuts = data_module.valid_cuts_en_vocalnet()
elif params.dataset_format == "speech_continuation":
train_cuts = data_module.train_cuts_ultravox()
# train_cuts = data_module.train_cuts_gigaspeech()
# train_cuts = data_module.train_cuts_librispeech()
if params.dataset == "multi_en":
train_cuts = data_module.train_cuts_ultravox()
elif params.dataset == "librispeech":
train_cuts = data_module.train_cuts_librispeech()
elif params.dataset == "gigaspeech":
train_cuts = data_module.train_cuts_gigaspeech()
else:
raise ValueError(f"Unknown dataset: {params.dataset}")
valid_cuts = data_module.valid_cuts_ultravox()
else:
raise ValueError(f"Unknown dataset format: {params.dataset_format}")