k2-fsa/icefall (https://github.com/k2-fsa/icefall.git)
commit 3642dfd8c3 (parent 2e9be46703): refactor code
@@ -51,9 +51,10 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
+  exp_dir=./slam_omni/exp_speech2speech_rerun
   python3 ./slam_omni/decode.py \
     --max-duration 1 \
-    --exp-dir slam_omni/exp_speech2speech_test_flash_attn \
+    --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --epoch 997 --avg 1 \
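Note: the `[ $stage -le 3 ] && [ $stop_stage -ge 3 ]` guard is the stage-gating idiom used throughout these icefall recipe scripts: a stage runs only when it falls inside the [stage, stop_stage] window, so factoring `exp_dir` out here keeps decoding pointed at whatever directory stage 5 trains into. A minimal Python sketch of the same predicate (the function name is illustrative, not from the repo):

    def stage_enabled(stage: int, stop_stage: int, s: int) -> bool:
        # stage s runs iff stage <= s <= stop_stage
        return stage <= s <= stop_stage

    assert stage_enabled(0, 100, 3)      # a wide-open window runs everything
    assert not stage_enabled(4, 100, 3)  # skipped once stage has moved past 3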
@@ -87,21 +88,23 @@ fi
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "stage 5: "
-  ngpu=2
-  exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
+  ngpu=8
+  exp_dir=./slam_omni/exp_speech2speech_rerun
+  # exp_dir_new=./slam_omni/exp_s2s
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
-    --max-duration 40 \
+    --max-duration 50 \
     --enable-musan False \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./slam_omni/ds_config_zero1.json \
     --use-flash-attn True \
-    --pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
+    --pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
+    --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
     --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
     # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
     # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
 
 fi
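Note: the rerun resumes training from checkpoint 15000: `--pretrained-model-path` restores the model weights and `--sampler-state-dict-path` restores the dataloader sampler, so the run continues mid-epoch instead of replaying already-seen cuts. A sketch of the resume mechanics, assuming ordinary torch.load/load_state_dict semantics rather than the exact train.py code (the helper name is hypothetical):

    import torch
    from torch import nn

    def resume_from_checkpoint(model: nn.Module, sampler, ckpt_path: str, sampler_path: str) -> None:
        state = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state, strict=False)  # strict=False tolerates missing/renamed heads
        sampler.load_state_dict(torch.load(sampler_path))  # lhotse samplers expose load_state_dict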
@@ -579,7 +579,7 @@ def main():
         # attn_implementation=attn_implementation,
         # torch_dtype=torch_dtype,
         # )
-        codec_vocab_size = 8192
+        codec_vocab_size = 4096 + 4
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
             hidden_size=1024,
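Note: the codec LM vocabulary shrinks from a loose 8192 to exactly 4096 codec codebook entries plus four special tokens; writing it as `4096 + 4` documents that split. The special ids are carved off the top of the vocabulary, matching the config assignments in the next hunk:

    codec_vocab_size = 4096 + 4
    pad_token_id = codec_vocab_size - 1   # 4099
    eos_token_id = codec_vocab_size - 2   # 4098
    bos_token_id = codec_vocab_size - 3   # 4097
    mask_token_id = codec_vocab_size - 4  # 4096
    assert mask_token_id == 4096  # ids 0..4095 remain for the real codec tokens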
@@ -603,24 +603,25 @@ def main():
         codec_lm.config.pad_token_id = codec_vocab_size - 1
         codec_lm.config.eos_token_id = codec_vocab_size - 2
         codec_lm.config.bos_token_id = codec_vocab_size - 3
-        if params.use_lora:
-            lora_config = LoraConfig(
-                r=64,
-                lora_alpha=16,
-                target_modules=[
-                    "q_proj",
-                    "k_proj",
-                    "v_proj",
-                    "o_proj",
-                    "up_proj",
-                    "gate_proj",
-                    "down_proj",
-                ],
-                lora_dropout=0.05,
-                task_type="CAUSAL_LM",
-            )
-            codec_lm = get_peft_model(codec_lm, lora_config)
-            codec_lm.print_trainable_parameters()
+        codec_lm.config.mask_token_id = codec_vocab_size - 4
+        # if params.use_lora:
+        #     lora_config = LoraConfig(
+        #         r=64,
+        #         lora_alpha=16,
+        #         target_modules=[
+        #             "q_proj",
+        #             "k_proj",
+        #             "v_proj",
+        #             "o_proj",
+        #             "up_proj",
+        #             "gate_proj",
+        #             "down_proj",
+        #         ],
+        #         lora_dropout=0.05,
+        #         task_type="CAUSAL_LM",
+        #     )
+        #     codec_lm = get_peft_model(codec_lm, lora_config)
+        #     codec_lm.print_trainable_parameters()
     else:
         codec_lm = None
@@ -4,7 +4,7 @@ from transformers.trainer_pt_utils import LabelSmoother
 from typing import List, Tuple  # Added for type hints
 from torchmetrics.classification import MulticlassAccuracy
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+import logging
 
 
 class EncoderProjector(nn.Module):
     """
@@ -69,7 +69,7 @@ class SPEECH_LLM(nn.Module):
         self.codec_lm = codec_lm
         if self.codec_lm:
             self.speech_token_projector = nn.Linear(
-                self.llm.config.hidden_size, self.codec_lm.config.hidden_size
+                self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
             )
             self.codec_lm_head = nn.Linear(
                 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
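Note: the projector's input width doubles because the training forward pass (next hunk) now concatenates the LLM's last-layer hidden state with the corresponding input token embedding along the feature dimension before projecting into the codec LM space; writing `hidden_size + hidden_size` spells out that concatenation. A self-contained shape check, with assumed sizes (896 is Qwen2.5-0.5B's hidden size, 1024 the codec LM's):

    import torch
    from torch import nn

    D_llm, D_codec, T = 896, 1024, 7  # T is an arbitrary sequence length
    projector = nn.Linear(D_llm + D_llm, D_codec)

    text_last_hidden = torch.randn(T, D_llm)  # LLM last-layer hidden states
    text_embed = torch.randn(T, D_llm)        # the matching input token embeddings
    fused = torch.cat([text_last_hidden, text_embed], dim=-1)  # (T, 2 * D_llm)
    assert projector(fused).shape == (T, D_codec)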
@@ -274,110 +274,92 @@ class SPEECH_LLM(nn.Module):
         ) = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
-        # get the label start_index in inputs_embeds from labels
-        text_label_start_index_list = []
+        input_seq_len = attention_mask.sum(dim=1)  # shape: B
+        text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
         for i in range(labels.shape[0]):
-            text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
-            text_label_start_index_list.append(text_label_start_index)
-            # TODO1: check text_label_start_index position
-            print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])
+            input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
+            input_embeds_start_index = input_embeds_valid_index[0]
+            text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
+            text_labels_start_index = text_labels_valid_index[0]
 
+            assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
+            assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
+            input_question_len = text_labels_start_index - input_embeds_start_index
+            assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
+            text_label_start_index_list.append(text_labels_start_index)
+            text_input_start_index_list.append(input_embeds_start_index)
+            input_question_len_list.append(input_question_len)
 
         model_outputs = self.llm(
             inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
         )
         text_loss = model_outputs.loss
+        delay_step = 1
         # prepare codec lm inputs
-        audio_codes_lens = torch.tensor(
-            [len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
-        )
-        # print(audio_codes_lens, "audio_codes_lens")
+        audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
         max_len_speech_codec = max(audio_codes_lens)
-        delay_step = 2
-        audio_codes = torch.full(
-            (inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1),
-            self.codec_lm.config.pad_token_id,
+
+        if self.codec_lm_padding_side == "right":
+            audio_codes = [
+                [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
+                for i, x in enumerate(speech_codec_ids)
+            ]
+            audio_labels = [
+                [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
+                for i, x in enumerate(speech_codec_ids)
+            ]
+        elif self.codec_lm_padding_side == "left":
+            audio_codes = [
+                [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
+                for i, x in enumerate(speech_codec_ids)
+            ]
+            audio_labels = [
+                [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
+                for i, x in enumerate(speech_codec_ids)
+            ]
+        audio_codes = torch.tensor(
+            audio_codes,
+            dtype=torch.int64,
+            device=input_ids.device
+        )
+        audio_labels = torch.tensor(
+            audio_labels,
             dtype=torch.int64,
             device=input_ids.device
         )
-        audio_labels = audio_codes.clone()
-        total_len = audio_codes.shape[1]
 
-        for i, speech_codec in enumerate(speech_codec_ids):
-            text_label_start_index = text_label_start_index_list[i]
-            speech_codec = torch.tensor(
-                speech_codec, dtype=torch.int64, device=input_ids.device
-            )
-            speech_codec_len = len(speech_codec)
-
-            # Calculate lengths of non-padding content
-            codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
-            # Actual label content length (speech codec tokens + eos token)
-            labels_actual_content_len = speech_codec_len + 1
-
-            if self.codec_lm_padding_side == "right":
-                # Fill audio_codes (right padding)
-                codes_end_idx = codes_len
-                audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
-                audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
-
-                # Fill audio_labels (right padding)
-                labels_start_idx = text_label_start_index + delay_step
-                labels_speech_end_idx = labels_start_idx + speech_codec_len
-                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
-                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
-
-            elif self.codec_lm_padding_side == "left":
-                # Calculate start indices for left padding (shifting content to the right)
-                codes_start_idx = total_len - codes_len
-                labels_start_idx = total_len - labels_actual_content_len  # Start index for the actual label content
-
-                # Fill audio_codes (left padding)
-                codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
-                audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id  # mask token_id
-                audio_codes[i, codes_speech_start_idx : total_len] = speech_codec
-
-                # Fill audio_labels (left padding)
-                labels_speech_end_idx = labels_start_idx + speech_codec_len
-                # Note: The beginning part remains pad_token_id
-                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
-                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
-            else:
-                raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
-
-        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)  # TODO: do we need to change bos tokens to pad token or mask token?
+        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
 
-        # input_ids: seq_len T1, audio_codec seq_len T2
-        text_last_hidden_outputs = model_outputs.hidden_states[-1]
-        text_input_embeds = inputs_embeds + text_last_hidden_outputs  # TODO: this calculation is wrong -- should it use the output tokens' embeddings?
-        text_input_embeds = self.speech_token_projector(text_input_embeds)
-
-        T_merged = text_input_embeds.shape[1]
-        T_audio = audio_embeddings.shape[1]
-
-        if self.codec_lm_padding_side == "right":
-            # Add to the beginning for right padding
-            audio_embeddings[:, :T_merged] += text_input_embeds
-        elif self.codec_lm_padding_side == "left":
-            # Need to add to the shifted position for left padding
-            # Calculate the length of the non-padded sequence for each item
-            seq_lens = audio_attention_mask.sum(dim=1)  # Shape (B)
-            print(seq_lens[0], audio_codes[0], "======================")
-            for i in range(audio_embeddings.shape[0]):
-                item_len = seq_lens[i].item()  # Get the non-padded length for item i
-                start_idx_content = T_audio - item_len  # Start index of the content for item i
-                end_idx_target = start_idx_content + T_merged  # End index of the target slice within the content
-                # Add the text_input_embeds to the calculated slice
-                if end_idx_target > T_audio:
-                    # If the text input is longer than the audio input, we need to truncate it
-                    cut_off_len = T_audio - start_idx_content
-                    audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
-                else:
-                    audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
-        else:
-            raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
+        text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        for i in range(len(text_label_start_index_list)):
+            text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
+            text_last_hidden_lists.append(text_last_hidden)
+            text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]]  # exclude bos
+            text_embeds_list.append(text_embed)
 
+            text_input_embeds = torch.cat(
+                [
+                    text_last_hidden,
+                    text_embed,
+                ],
+                dim=-1,
+            )  # shape: T, D1 + D2
+            text_input_embeds = self.speech_token_projector(text_input_embeds)  # shape: T, D_codec
+            text_input_embeds_list.append(text_input_embeds)
 
+        for i in range(audio_embeddings.shape[0]):
+            text_input_embeds = text_input_embeds_list[i]
+            if self.codec_lm_padding_side == "right":
+                audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
+            elif self.codec_lm_padding_side == "left":
+                start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
+                start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
+                assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
+                if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
+                    text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
+                    logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
+                audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds
 
         speech_outputs = self.codec_lm(
             attention_mask=audio_attention_mask,
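Note: the rewritten input preparation builds the talker sequences declaratively: each sample gets a mask prefix covering the question (plus delay_step), then bos, then the codec tokens, and the labels are the codec tokens shifted one step earlier so that bos predicts the first code. A toy right-padded example using the special ids configured above (values illustrative; real batches are additionally padded out to max_len_speech_codec):

    mask, bos, eos, pad = 4096, 4097, 4098, 4099  # special ids as configured above
    q_len, delay = 2, 1                           # question length and delay_step
    x = [11, 12, 13]                              # codec tokens for the answer speech

    audio_codes = [mask] * (q_len + delay) + [bos] + x   # what the talker reads
    audio_labels = [pad] * (q_len + delay) + x + [eos]   # what it must predict
    # codes : [mask, mask, mask, bos, 11, 12, 13]
    # labels: [pad,  pad,  pad,  11,  12, 13, eos]
    assert len(audio_codes) == len(audio_labels) == len(x) + q_len + delay + 1
    # Position by position: bos predicts 11, 11 predicts 12, ..., 13 predicts eos;
    # the projected text features are added onto the mask positions.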
@@ -545,26 +527,56 @@ class SPEECH_LLM(nn.Module):
             output_hidden_states=True,
             **final_llm_kwargs
         )
+        delay_step = 1
         generated_text_ids = text_outputs.sequences  # [B, S_full]
-        thinker_token_embeds = [
+        eos_token_id = self.llm.config.eos_token_id
+        eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device))  # 1, D
+        assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
+        thinker_token_embeds_org = [
             token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
         ]
+        # shift the thinker token embeds by one: drop the first embedding and append the eos embedding
+        first_thinker_token_embed = torch.cat(
+            [
+                thinker_token_embeds_org[0][:, 1:],
+                thinker_token_embeds_org[1],
+            ],
+            dim=1,
+        )
+
+        thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
         thinker_hidden_states = [
             token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
         ]
-        thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
-        thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
+        # thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+        thinker_reply_part = [
+            torch.cat(
+                [
+                    thinker_hidden_state,
+                    thinker_token_embed,
+                ],
+                dim=-1,
+            )
+            for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
+        ]
+        thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
+        # thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
+        thinker_prompt_part = torch.cat(
+            [
+                thinker_hidden_states[0],
+                thinker_token_embeds[0],
+            ],
+            dim=-1,
+        )
+
         thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)  # [B, S_full, D_codec]
         thinker_reply_part = self.speech_token_projector(thinker_reply_part)  # [B, S_full, D_codec]
 
-        delay_step = 2
         thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
         talker_input_ids = torch.full(
-            (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
+            (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
        )
+        talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
         talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)  # [B, S_full, D_codec]
         thinker_input_embeds = torch.cat(
             [
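Note: at inference the same concat-then-project scheme is reproduced step by step: each generation step's last-layer hidden state is paired with the embedding of the token it produced (hence dropping the first prompt embedding and appending the eos embedding), concatenated along the feature dimension, and fed through `speech_token_projector`. A shapes-only sketch under assumed sizes, not the repo's code:

    import torch

    D = 896  # assumed LLM hidden size
    steps = 5
    hidden_states = [torch.randn(1, 1, D) for _ in range(steps)]  # last-layer state per step
    token_embeds = [torch.randn(1, 1, D) for _ in range(steps)]   # shifted token embeddings
    reply = torch.cat(
        [torch.cat([h, e], dim=-1) for h, e in zip(hidden_states[1:], token_embeds[1:])],
        dim=1,
    )
    assert reply.shape == (1, steps - 1, 2 * D)  # speech_token_projector then maps to D_codec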
@@ -614,7 +626,7 @@ class SPEECH_LLM(nn.Module):
             # Get logits for the *last* token generated in this step
             next_token_logits = self.codec_lm_head(last_token_hidden_state)  # Use -1 index
             # suppress tokens between 4096:len(vocab)-3
-            next_token_logits[:, 4096:-3] = -float("Inf")  # TODO: where should we suppress tokens?
+            # next_token_logits[:, 4096:-3] = -float("Inf")  # TODO: where should we suppress tokens?
             next_token_ids = topk_sampling(
                 next_token_logits,
             )
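Note: with `codec_vocab_size = 4096 + 4`, the slice `[:, 4096:-3]` covers only index 4096 (the mask token), so the old blanket suppression of ids 4096..8188 no longer has anything to suppress and is commented out. If the special tokens still need to be kept out of the sampled codec stream, masking them explicitly would be one option; this is a hypothetical alternative, not what the commit does:

    import torch

    codec_vocab_size = 4096 + 4
    mask_id, bos_id, pad_id = codec_vocab_size - 4, codec_vocab_size - 3, codec_vocab_size - 1
    next_token_logits = torch.randn(1, codec_vocab_size)  # stand-in for codec_lm_head output
    next_token_logits[:, [mask_id, bos_id, pad_id]] = -float("inf")  # eos stays sampleable so decoding can stop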
@@ -745,7 +745,7 @@ def run(rank, world_size, args):
         # attn_implementation=attn_implementation,
         # torch_dtype=torch_dtype,
         # )
-        codec_vocab_size = 8192
+        codec_vocab_size = 4096 + 4
         # TODO: modify the vocab size above or suppress_tokens when decoding
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
@@ -770,24 +770,25 @@ def run(rank, world_size, args):
         codec_lm.config.pad_token_id = codec_vocab_size - 1
         codec_lm.config.eos_token_id = codec_vocab_size - 2
         codec_lm.config.bos_token_id = codec_vocab_size - 3
-        if params.use_lora:
-            lora_config = LoraConfig(
-                r=64,
-                lora_alpha=16,
-                target_modules=[
-                    "q_proj",
-                    "k_proj",
-                    "v_proj",
-                    "o_proj",
-                    "up_proj",
-                    "gate_proj",
-                    "down_proj",
-                ],
-                lora_dropout=0.05,
-                task_type="CAUSAL_LM",
-            )
-            codec_lm = get_peft_model(codec_lm, lora_config)
-            codec_lm.print_trainable_parameters()
+        codec_lm.config.mask_token_id = codec_vocab_size - 4
+        # if params.use_lora:
+        #     lora_config = LoraConfig(
+        #         r=64,
+        #         lora_alpha=16,
+        #         target_modules=[
+        #             "q_proj",
+        #             "k_proj",
+        #             "v_proj",
+        #             "o_proj",
+        #             "up_proj",
+        #             "gate_proj",
+        #             "down_proj",
+        #         ],
+        #         lora_dropout=0.05,
+        #         task_type="CAUSAL_LM",
+        #     )
+        #     codec_lm = get_peft_model(codec_lm, lora_config)
+        #     codec_lm.print_trainable_parameters()
     else:
         codec_lm = None