refactor code

This commit is contained in:
root 2025-04-25 05:36:18 +00:00
parent 2e9be46703
commit 3642dfd8c3
4 changed files with 171 additions and 154 deletions

View File

@ -51,9 +51,10 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: " log "stage 3: "
exp_dir=./slam_omni/exp_speech2speech_rerun
python3 ./slam_omni/decode.py \ python3 ./slam_omni/decode.py \
--max-duration 1 \ --max-duration 1 \
--exp-dir slam_omni/exp_speech2speech_test_flash_attn \ --exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 997 --avg 1 \ --epoch 997 --avg 1 \
@ -87,21 +88,23 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: " log "stage 5: "
ngpu=2 ngpu=8
exp_dir=./slam_omni/exp_speech2speech_test_flash_attn exp_dir=./slam_omni/exp_speech2speech_rerun
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \ # exp_dir_new=./slam_omni/exp_s2s
--max-duration 40 \ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
--enable-musan False \ --max-duration 50 \
--exp-dir $exp_dir \ --enable-musan False \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ --exp-dir $exp_dir \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--manifest-dir data/fbank \ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--deepspeed \ --manifest-dir data/fbank \
--deepspeed_config ./slam_omni/ds_config_zero1.json \ --deepspeed \
--use-flash-attn True \ --deepspeed_config ./slam_omni/ds_config_zero1.json \
--pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \ --use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True --pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \ --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \ --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
fi fi

View File

@ -579,7 +579,7 @@ def main():
# attn_implementation=attn_implementation, # attn_implementation=attn_implementation,
# torch_dtype=torch_dtype, # torch_dtype=torch_dtype,
# ) # )
codec_vocab_size = 8192 codec_vocab_size = 4096 + 4
config = Qwen2Config( config = Qwen2Config(
vocab_size=codec_vocab_size, vocab_size=codec_vocab_size,
hidden_size=1024, hidden_size=1024,
@ -603,24 +603,25 @@ def main():
codec_lm.config.pad_token_id = codec_vocab_size - 1 codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2 codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3 codec_lm.config.bos_token_id = codec_vocab_size - 3
if params.use_lora: codec_lm.config.mask_token_id = codec_vocab_size - 4
lora_config = LoraConfig( # if params.use_lora:
r=64, # lora_config = LoraConfig(
lora_alpha=16, # r=64,
target_modules=[ # lora_alpha=16,
"q_proj", # target_modules=[
"k_proj", # "q_proj",
"v_proj", # "k_proj",
"o_proj", # "v_proj",
"up_proj", # "o_proj",
"gate_proj", # "up_proj",
"down_proj", # "gate_proj",
], # "down_proj",
lora_dropout=0.05, # ],
task_type="CAUSAL_LM", # lora_dropout=0.05,
) # task_type="CAUSAL_LM",
codec_lm = get_peft_model(codec_lm, lora_config) # )
codec_lm.print_trainable_parameters() # codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else: else:
codec_lm = None codec_lm = None

View File

@ -4,7 +4,7 @@ from transformers.trainer_pt_utils import LabelSmoother
from typing import List, Tuple # Added for type hints from typing import List, Tuple # Added for type hints
from torchmetrics.classification import MulticlassAccuracy from torchmetrics.classification import MulticlassAccuracy
IGNORE_TOKEN_ID = LabelSmoother.ignore_index IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import logging
class EncoderProjector(nn.Module): class EncoderProjector(nn.Module):
""" """
@ -69,7 +69,7 @@ class SPEECH_LLM(nn.Module):
self.codec_lm = codec_lm self.codec_lm = codec_lm
if self.codec_lm: if self.codec_lm:
self.speech_token_projector = nn.Linear( self.speech_token_projector = nn.Linear(
self.llm.config.hidden_size, self.codec_lm.config.hidden_size self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
) )
self.codec_lm_head = nn.Linear( self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
@ -274,110 +274,92 @@ class SPEECH_LLM(nn.Module):
) = self._merge_input_ids_with_speech_features( ) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels speech_features, inputs_embeds, input_ids, attention_mask, labels
) )
input_seq_len = attention_mask.sum(dim=1) # shape, B
# get the label start_index in inputs_embeds from labels text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
text_label_start_index_list = []
for i in range(labels.shape[0]): for i in range(labels.shape[0]):
text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0] input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
text_label_start_index_list.append(text_label_start_index) input_embeds_start_index = input_embeds_valid_index[0]
# TODO1: check text_label_start_index position text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index]) text_labels_start_index = text_labels_valid_index[0]
assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index
assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len)
model_outputs = self.llm( model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
) )
text_loss = model_outputs.loss text_loss = model_outputs.loss
delay_step = 1
# prepare codec lm inputs # prepare codec lm inputs
audio_codes_lens = torch.tensor( audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
[len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
)
# print(audio_codes_lens, "audio_codes_lens")
max_len_speech_codec = max(audio_codes_lens) max_len_speech_codec = max(audio_codes_lens)
delay_step = 2
audio_codes = torch.full( if self.codec_lm_padding_side == "right":
(inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1), audio_codes = [
self.codec_lm.config.pad_token_id, [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
elif self.codec_lm_padding_side == "left":
audio_codes = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids)
]
audio_codes = torch.tensor(
audio_codes,
dtype=torch.int64,
device=input_ids.device
)
audio_labels = torch.tensor(
audio_labels,
dtype=torch.int64, dtype=torch.int64,
device=input_ids.device device=input_ids.device
) )
audio_labels = audio_codes.clone()
total_len = audio_codes.shape[1]
for i, speech_codec in enumerate(speech_codec_ids): audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
text_label_start_index = text_label_start_index_list[i]
speech_codec = torch.tensor(
speech_codec, dtype=torch.int64, device=input_ids.device
)
speech_codec_len = len(speech_codec)
# Calculate lengths of non-padding content
codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
# Actual label content length (speech codec tokens + eos token)
labels_actual_content_len = speech_codec_len + 1
if self.codec_lm_padding_side == "right":
# Fill audio_codes (right padding)
codes_end_idx = codes_len
audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
# Fill audio_labels (right padding)
labels_start_idx = text_label_start_index + delay_step
labels_speech_end_idx = labels_start_idx + speech_codec_len
audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
elif self.codec_lm_padding_side == "left":
# Calculate start indices for left padding (shifting content to the right)
codes_start_idx = total_len - codes_len
labels_start_idx = total_len - labels_actual_content_len # Start index for the actual label content
# Fill audio_codes (left padding)
codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id # mask token_id
audio_codes[i, codes_speech_start_idx : total_len] = speech_codec
# Fill audio_labels (left padding)
labels_speech_end_idx = labels_start_idx + speech_codec_len
# Note: The beginning part remains pad_token_id
audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
else:
raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id) # TODO: do we need to change bos tokens to pad token or mask token?
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes) audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
# input_ids: seq_len T1, audio_codec seq_len T2 text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
text_last_hidden_outputs = model_outputs.hidden_states[-1] for i in range(len(text_label_start_index_list)):
text_input_embeds = inputs_embeds + text_last_hidden_outputs # TODO: 计算不对output tokens' embedding? text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
text_input_embeds = self.speech_token_projector(text_input_embeds) text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos
text_embeds_list.append(text_embed)
T_merged = text_input_embeds.shape[1] text_input_embeds = torch.cat(
T_audio = audio_embeddings.shape[1] [
text_last_hidden,
if self.codec_lm_padding_side == "right": text_embed,
# Add to the beginning for right padding ],
audio_embeddings[:, :T_merged] += text_input_embeds dim=-1,
elif self.codec_lm_padding_side == "left": )# shape, T, D1 + D2
# Need to add to the shifted position for left padding text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec
# Calculate the length of the non-padded sequence for each item text_input_embeds_list.append(text_input_embeds)
seq_lens = audio_attention_mask.sum(dim=1) # Shape (B)
print(seq_lens[0], audio_codes[0], "======================") for i in range(audio_embeddings.shape[0]):
for i in range(audio_embeddings.shape[0]): text_input_embeds = text_input_embeds_list[i]
item_len = seq_lens[i].item() # Get the non-padded length for item i if self.codec_lm_padding_side == "right":
start_idx_content = T_audio - item_len # Start index of the content for item i audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
end_idx_target = start_idx_content + T_merged # End index of the target slice within the content elif self.codec_lm_padding_side == "left":
# Add the text_input_embeds to the calculated slice start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
if end_idx_target > T_audio: start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
# If the text input is longer than the audio input, we need to pad the audio input assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
cut_off_len = T_audio - start_idx_content if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len] text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
else: logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i] audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds
else:
raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
speech_outputs = self.codec_lm( speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask, attention_mask=audio_attention_mask,
@ -545,26 +527,56 @@ class SPEECH_LLM(nn.Module):
output_hidden_states=True, output_hidden_states=True,
**final_llm_kwargs **final_llm_kwargs
) )
delay_step = 1
generated_text_ids = text_outputs.sequences # [B, S_full] generated_text_ids = text_outputs.sequences # [B, S_full]
thinker_token_embeds = [ eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D
assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
] ]
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
thinker_token_embeds_org[1],
],
dim=1,
)
thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
thinker_hidden_states = [ thinker_hidden_states = [
token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
] ]
thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) # thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0] thinker_reply_part = [torch.cat(
[
thinker_hidden_state,
thinker_token_embed,
],
dim=-1,
)
for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
thinker_prompt_part = torch.cat(
[
thinker_hidden_states[0],
thinker_token_embeds[0],
],
dim=-1,
)
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec] thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec]
thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec] thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec]
delay_step = 2
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1] thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full( talker_input_ids = torch.full(
(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
) )
talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec] talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
thinker_input_embeds = torch.cat( thinker_input_embeds = torch.cat(
[ [
@ -614,7 +626,7 @@ class SPEECH_LLM(nn.Module):
# Get logits for the *last* token generated in this step # Get logits for the *last* token generated in this step
next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
# suppress tokens between 4096:len(vocab)-3 # suppress tokens between 4096:len(vocab)-3
next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens? # next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
next_token_ids = topk_sampling( next_token_ids = topk_sampling(
next_token_logits, next_token_logits,
) )

View File

@ -745,7 +745,7 @@ def run(rank, world_size, args):
# attn_implementation=attn_implementation, # attn_implementation=attn_implementation,
# torch_dtype=torch_dtype, # torch_dtype=torch_dtype,
# ) # )
codec_vocab_size = 8192 codec_vocab_size = 4096 + 4
# TODO: modify above vocab size or supress_tokens when decoding # TODO: modify above vocab size or supress_tokens when decoding
config = Qwen2Config( config = Qwen2Config(
vocab_size=codec_vocab_size, vocab_size=codec_vocab_size,
@ -770,24 +770,25 @@ def run(rank, world_size, args):
codec_lm.config.pad_token_id = codec_vocab_size - 1 codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2 codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3 codec_lm.config.bos_token_id = codec_vocab_size - 3
if params.use_lora: codec_lm.config.mask_token_id = codec_vocab_size - 4
lora_config = LoraConfig( # if params.use_lora:
r=64, # lora_config = LoraConfig(
lora_alpha=16, # r=64,
target_modules=[ # lora_alpha=16,
"q_proj", # target_modules=[
"k_proj", # "q_proj",
"v_proj", # "k_proj",
"o_proj", # "v_proj",
"up_proj", # "o_proj",
"gate_proj", # "up_proj",
"down_proj", # "gate_proj",
], # "down_proj",
lora_dropout=0.05, # ],
task_type="CAUSAL_LM", # lora_dropout=0.05,
) # task_type="CAUSAL_LM",
codec_lm = get_peft_model(codec_lm, lora_config) # )
codec_lm.print_trainable_parameters() # codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else: else:
codec_lm = None codec_lm = None