refactor code

This commit is contained in:
root 2025-04-25 05:36:18 +00:00
parent 2e9be46703
commit 3642dfd8c3
4 changed files with 171 additions and 154 deletions

View File

@ -51,9 +51,10 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: "
exp_dir=./slam_omni/exp_speech2speech_rerun
python3 ./slam_omni/decode.py \
--max-duration 1 \
--exp-dir slam_omni/exp_speech2speech_test_flash_attn \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 997 --avg 1 \
@ -87,21 +88,23 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: "
ngpu=2
exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
--max-duration 40 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./slam_omni/ds_config_zero1.json \
--use-flash-attn True \
--pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
ngpu=8
exp_dir=./slam_omni/exp_speech2speech_rerun
# exp_dir_new=./slam_omni/exp_s2s
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./slam_omni/ds_config_zero1.json \
--use-flash-attn True \
--pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
--sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
fi

View File

@ -579,7 +579,7 @@ def main():
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 8192
codec_vocab_size = 4096 + 4
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
@ -603,24 +603,25 @@ def main():
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
codec_lm = get_peft_model(codec_lm, lora_config)
codec_lm.print_trainable_parameters()
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
codec_lm = None

View File

@ -4,7 +4,7 @@ from transformers.trainer_pt_utils import LabelSmoother
from typing import List, Tuple # Added for type hints
from torchmetrics.classification import MulticlassAccuracy
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import logging
class EncoderProjector(nn.Module):
"""
@ -69,7 +69,7 @@ class SPEECH_LLM(nn.Module):
self.codec_lm = codec_lm
if self.codec_lm:
self.speech_token_projector = nn.Linear(
self.llm.config.hidden_size, self.codec_lm.config.hidden_size
self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
)
self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
@ -274,110 +274,92 @@ class SPEECH_LLM(nn.Module):
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
# get the label start_index in inputs_embeds from labels
text_label_start_index_list = []
input_seq_len = attention_mask.sum(dim=1) # shape, B
text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
for i in range(labels.shape[0]):
text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
text_label_start_index_list.append(text_label_start_index)
# TODO1: check text_label_start_index position
print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
input_embeds_start_index = input_embeds_valid_index[0]
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
text_labels_start_index = text_labels_valid_index[0]
assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index
assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
)
text_loss = model_outputs.loss
delay_step = 1
# prepare codec lm inputs
audio_codes_lens = torch.tensor(
[len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
)
# print(audio_codes_lens, "audio_codes_lens")
audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
max_len_speech_codec = max(audio_codes_lens)
delay_step = 2
audio_codes = torch.full(
(inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1),
self.codec_lm.config.pad_token_id,
if self.codec_lm_padding_side == "right":
audio_codes = [
[self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
elif self.codec_lm_padding_side == "left":
audio_codes = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids)
]
audio_codes = torch.tensor(
audio_codes,
dtype=torch.int64,
device=input_ids.device
)
audio_labels = torch.tensor(
audio_labels,
dtype=torch.int64,
device=input_ids.device
)
audio_labels = audio_codes.clone()
total_len = audio_codes.shape[1]
for i, speech_codec in enumerate(speech_codec_ids):
text_label_start_index = text_label_start_index_list[i]
speech_codec = torch.tensor(
speech_codec, dtype=torch.int64, device=input_ids.device
)
speech_codec_len = len(speech_codec)
# Calculate lengths of non-padding content
codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
# Actual label content length (speech codec tokens + eos token)
labels_actual_content_len = speech_codec_len + 1
if self.codec_lm_padding_side == "right":
# Fill audio_codes (right padding)
codes_end_idx = codes_len
audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
# Fill audio_labels (right padding)
labels_start_idx = text_label_start_index + delay_step
labels_speech_end_idx = labels_start_idx + speech_codec_len
audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
elif self.codec_lm_padding_side == "left":
# Calculate start indices for left padding (shifting content to the right)
codes_start_idx = total_len - codes_len
labels_start_idx = total_len - labels_actual_content_len # Start index for the actual label content
# Fill audio_codes (left padding)
codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id # mask token_id
audio_codes[i, codes_speech_start_idx : total_len] = speech_codec
# Fill audio_labels (left padding)
labels_speech_end_idx = labels_start_idx + speech_codec_len
# Note: The beginning part remains pad_token_id
audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
else:
raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id) # TODO: do we need to change bos tokens to pad token or mask token?
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
# input_ids: seq_len T1, audio_codec seq_len T2
text_last_hidden_outputs = model_outputs.hidden_states[-1]
text_input_embeds = inputs_embeds + text_last_hidden_outputs # TODO: 计算不对output tokens' embedding?
text_input_embeds = self.speech_token_projector(text_input_embeds)
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
for i in range(len(text_label_start_index_list)):
text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos
text_embeds_list.append(text_embed)
T_merged = text_input_embeds.shape[1]
T_audio = audio_embeddings.shape[1]
if self.codec_lm_padding_side == "right":
# Add to the beginning for right padding
audio_embeddings[:, :T_merged] += text_input_embeds
elif self.codec_lm_padding_side == "left":
# Need to add to the shifted position for left padding
# Calculate the length of the non-padded sequence for each item
seq_lens = audio_attention_mask.sum(dim=1) # Shape (B)
print(seq_lens[0], audio_codes[0], "======================")
for i in range(audio_embeddings.shape[0]):
item_len = seq_lens[i].item() # Get the non-padded length for item i
start_idx_content = T_audio - item_len # Start index of the content for item i
end_idx_target = start_idx_content + T_merged # End index of the target slice within the content
# Add the text_input_embeds to the calculated slice
if end_idx_target > T_audio:
# If the text input is longer than the audio input, we need to pad the audio input
cut_off_len = T_audio - start_idx_content
audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
else:
audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
else:
raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
text_input_embeds = torch.cat(
[
text_last_hidden,
text_embed,
],
dim=-1,
)# shape, T, D1 + D2
text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec
text_input_embeds_list.append(text_input_embeds)
for i in range(audio_embeddings.shape[0]):
text_input_embeds = text_input_embeds_list[i]
if self.codec_lm_padding_side == "right":
audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
elif self.codec_lm_padding_side == "left":
start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds
speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask,
@ -545,26 +527,56 @@ class SPEECH_LLM(nn.Module):
output_hidden_states=True,
**final_llm_kwargs
)
delay_step = 1
generated_text_ids = text_outputs.sequences # [B, S_full]
thinker_token_embeds = [
eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D
assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
]
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
thinker_token_embeds_org[1],
],
dim=1,
)
thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
thinker_hidden_states = [
token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
]
thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_reply_part = [torch.cat(
[
thinker_hidden_state,
thinker_token_embed,
],
dim=-1,
)
for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
thinker_prompt_part = torch.cat(
[
thinker_hidden_states[0],
thinker_token_embeds[0],
],
dim=-1,
)
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec]
thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec]
delay_step = 2
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
)
talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
thinker_input_embeds = torch.cat(
[
@ -614,7 +626,7 @@ class SPEECH_LLM(nn.Module):
# Get logits for the *last* token generated in this step
next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
# suppress tokens between 4096:len(vocab)-3
next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
next_token_ids = topk_sampling(
next_token_logits,
)

View File

@ -745,7 +745,7 @@ def run(rank, world_size, args):
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 8192
codec_vocab_size = 4096 + 4
# TODO: modify above vocab size or supress_tokens when decoding
config = Qwen2Config(
vocab_size=codec_vocab_size,
@ -770,24 +770,25 @@ def run(rank, world_size, args):
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
codec_lm = get_peft_model(codec_lm, lora_config)
codec_lm.print_trainable_parameters()
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
codec_lm = None