Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

refactor code

Commit 3642dfd8c3 (parent 2e9be46703)
@@ -51,9 +51,10 @@ fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "stage 3: "
  exp_dir=./slam_omni/exp_speech2speech_rerun
  python3 ./slam_omni/decode.py \
    --max-duration 1 \
    --exp-dir slam_omni/exp_speech2speech_test_flash_attn \
    --exp-dir $exp_dir \
    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --epoch 997 --avg 1 \
@@ -87,21 +88,23 @@ fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "stage 5: "
  ngpu=2
  exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
    --max-duration 40 \
    --enable-musan False \
    --exp-dir $exp_dir \
    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --manifest-dir data/fbank \
    --deepspeed \
    --deepspeed_config ./slam_omni/ds_config_zero1.json \
    --use-flash-attn True \
    --pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
    # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
    # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
  ngpu=8
  exp_dir=./slam_omni/exp_speech2speech_rerun
  # exp_dir_new=./slam_omni/exp_s2s
  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
    --max-duration 50 \
    --enable-musan False \
    --exp-dir $exp_dir \
    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --manifest-dir data/fbank \
    --deepspeed \
    --deepspeed_config ./slam_omni/ds_config_zero1.json \
    --use-flash-attn True \
    --pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
    --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
    # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
    # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \

fi
@@ -579,7 +579,7 @@ def main():

# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 8192
codec_vocab_size = 4096 + 4
config = Qwen2Config(
    vocab_size=codec_vocab_size,
    hidden_size=1024,
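The codec LM is now sized at 4096 + 4: 4096 speech codec tokens plus four special tokens that are assigned in the next hunk. As a rough illustration, a small Qwen2-style codec LM could be built from such a config as sketched below; only vocab_size and hidden_size come from the diff, the remaining kwargs and the Qwen2ForCausalLM class are assumptions.

# Sketch only: a small Qwen2-style causal LM used as the codec LM.
# vocab_size and hidden_size come from the hunk above; every other kwarg and
# the choice of Qwen2ForCausalLM are illustrative assumptions.
from transformers import Qwen2Config, Qwen2ForCausalLM

codec_vocab_size = 4096 + 4  # 4096 codec tokens + 4 special tokens
config = Qwen2Config(
    vocab_size=codec_vocab_size,
    hidden_size=1024,
    num_hidden_layers=12,     # assumed depth
    num_attention_heads=16,   # assumed head count
    num_key_value_heads=16,   # assumed (no GQA)
    intermediate_size=2048,   # assumed FFN width
)
codec_lm = Qwen2ForCausalLM(config)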
@@ -603,24 +603,25 @@ def main():

codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
if params.use_lora:
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "gate_proj",
            "down_proj",
        ],
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )
    codec_lm = get_peft_model(codec_lm, lora_config)
    codec_lm.print_trainable_parameters()
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
    codec_lm = None
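With codec_vocab_size = 4096 + 4, the assignments above pin the special tokens to the top four ids of the codec vocabulary. The following sketch just spells out the resulting layout:

# Special-token layout implied by the assignments above for codec_vocab_size = 4100.
codec_vocab_size = 4096 + 4

pad_token_id = codec_vocab_size - 1   # 4099
eos_token_id = codec_vocab_size - 2   # 4098
bos_token_id = codec_vocab_size - 3   # 4097
mask_token_id = codec_vocab_size - 4  # 4096 (newly added in this commit)

# ids 0 .. 4095 stay reserved for the speech codec tokens themselves
assert mask_token_id == 4096 and pad_token_id == codec_vocab_size - 1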
@@ -4,7 +4,7 @@ from transformers.trainer_pt_utils import LabelSmoother

from typing import List, Tuple  # Added for type hints
from torchmetrics.classification import MulticlassAccuracy

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

import logging


class EncoderProjector(nn.Module):
    """
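The new MulticlassAccuracy import is presumably intended for tracking codec-token prediction accuracy; its use is not shown in this hunk. A minimal, hypothetical usage sketch (num_classes and ignore_index mirror the codec vocabulary and pad id used elsewhere in this diff):

# Hypothetical sketch: top-1 codec-token accuracy with torchmetrics, assuming the
# 4096 + 4 codec vocabulary and the pad id (vocab_size - 1) used in this diff.
import torch
from torchmetrics.classification import MulticlassAccuracy

codec_vocab_size = 4096 + 4
metric = MulticlassAccuracy(
    num_classes=codec_vocab_size,
    ignore_index=codec_vocab_size - 1,  # skip pad positions
    average="micro",
)

logits = torch.randn(2, 10, codec_vocab_size)   # (B, T, V) dummy logits
targets = torch.randint(0, 4096, (2, 10))       # (B, T) dummy codec ids
acc = metric(logits.reshape(-1, codec_vocab_size), targets.reshape(-1))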
@@ -69,7 +69,7 @@ class SPEECH_LLM(nn.Module):

self.codec_lm = codec_lm
if self.codec_lm:
    self.speech_token_projector = nn.Linear(
        self.llm.config.hidden_size, self.codec_lm.config.hidden_size
        self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
    )
    self.codec_lm_head = nn.Linear(
        self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
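The projector's input width is doubled because, later in this diff, the LLM's last hidden state is concatenated with the corresponding token embedding along the feature dimension before being projected into the codec LM's hidden size. A shape-level sketch of that contract, with illustrative sizes:

# Shape sketch of the widened projector: concat(hidden_state, token_embedding)
# along the feature axis, then project into the codec LM's hidden size.
import torch
import torch.nn as nn

llm_hidden, codec_hidden, T = 896, 1024, 12  # illustrative sizes

speech_token_projector = nn.Linear(llm_hidden + llm_hidden, codec_hidden)

text_last_hidden = torch.randn(1, T, llm_hidden)  # LLM last hidden states
text_embed = torch.randn(1, T, llm_hidden)        # embeddings of the next tokens
merged = torch.cat([text_last_hidden, text_embed], dim=-1)  # (1, T, 2 * llm_hidden)
projected = speech_token_projector(merged)                  # (1, T, codec_hidden)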
@@ -274,110 +274,92 @@ class SPEECH_LLM(nn.Module):

) = self._merge_input_ids_with_speech_features(
    speech_features, inputs_embeds, input_ids, attention_mask, labels
)

# get the label start_index in inputs_embeds from labels
text_label_start_index_list = []
input_seq_len = attention_mask.sum(dim=1)  # shape: (B,)
text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
for i in range(labels.shape[0]):
    text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
    text_label_start_index_list.append(text_label_start_index)
    # TODO1: check text_label_start_index position
    print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])
    input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
    input_embeds_start_index = input_embeds_valid_index[0]
    text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
    text_labels_start_index = text_labels_valid_index[0]

    assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
    assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
    input_question_len = text_labels_start_index - input_embeds_start_index
    assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
    text_label_start_index_list.append(text_labels_start_index)
    text_input_start_index_list.append(input_embeds_start_index)
    input_question_len_list.append(input_question_len)

model_outputs = self.llm(
    inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
)
text_loss = model_outputs.loss

delay_step = 1
# prepare codec lm inputs
audio_codes_lens = torch.tensor(
    [len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
)
# print(audio_codes_lens, "audio_codes_lens")
audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
max_len_speech_codec = max(audio_codes_lens)
delay_step = 2
audio_codes = torch.full(
    (inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1),
    self.codec_lm.config.pad_token_id,

if self.codec_lm_padding_side == "right":
    audio_codes = [
        [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
        for i, x in enumerate(speech_codec_ids)
    ]
    audio_labels = [
        [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
        for i, x in enumerate(speech_codec_ids)
    ]
elif self.codec_lm_padding_side == "left":
    audio_codes = [
        [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
        for i, x in enumerate(speech_codec_ids)
    ]
    audio_labels = [
        [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
        for i, x in enumerate(speech_codec_ids)
    ]
audio_codes = torch.tensor(
    audio_codes,
    dtype=torch.int64,
    device=input_ids.device
)
audio_labels = torch.tensor(
    audio_labels,
    dtype=torch.int64,
    device=input_ids.device
)
audio_labels = audio_codes.clone()
total_len = audio_codes.shape[1]

for i, speech_codec in enumerate(speech_codec_ids):
    text_label_start_index = text_label_start_index_list[i]
    speech_codec = torch.tensor(
        speech_codec, dtype=torch.int64, device=input_ids.device
    )
    speech_codec_len = len(speech_codec)

    # Calculate lengths of non-padding content
    codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
    # Actual label content length (speech codec tokens + eos token)
    labels_actual_content_len = speech_codec_len + 1

    if self.codec_lm_padding_side == "right":
        # Fill audio_codes (right padding)
        codes_end_idx = codes_len
        audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
        audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec

        # Fill audio_labels (right padding)
        labels_start_idx = text_label_start_index + delay_step
        labels_speech_end_idx = labels_start_idx + speech_codec_len
        audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
        audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id

    elif self.codec_lm_padding_side == "left":
        # Calculate start indices for left padding (shifting content to the right)
        codes_start_idx = total_len - codes_len
        labels_start_idx = total_len - labels_actual_content_len  # Start index for the actual label content

        # Fill audio_codes (left padding)
        codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
        audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id  # mask token_id
        audio_codes[i, codes_speech_start_idx : total_len] = speech_codec

        # Fill audio_labels (left padding)
        labels_speech_end_idx = labels_start_idx + speech_codec_len
        # Note: The beginning part remains pad_token_id
        audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
        audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
    else:
        raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")

audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)  # TODO: do we need to change bos tokens to pad token or mask token?
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

# input_ids: seq_len T1, audio_codec seq_len T2
text_last_hidden_outputs = model_outputs.hidden_states[-1]
text_input_embeds = inputs_embeds + text_last_hidden_outputs  # TODO: this computation is wrong; should it use the output tokens' embeddings?
text_input_embeds = self.speech_token_projector(text_input_embeds)
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
for i in range(len(text_label_start_index_list)):
    text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
    text_last_hidden_lists.append(text_last_hidden)
    text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]]  # exclude bos
    text_embeds_list.append(text_embed)

T_merged = text_input_embeds.shape[1]
T_audio = audio_embeddings.shape[1]

if self.codec_lm_padding_side == "right":
    # Add to the beginning for right padding
    audio_embeddings[:, :T_merged] += text_input_embeds
elif self.codec_lm_padding_side == "left":
    # Need to add to the shifted position for left padding
    # Calculate the length of the non-padded sequence for each item
    seq_lens = audio_attention_mask.sum(dim=1)  # shape: (B,)
    print(seq_lens[0], audio_codes[0], "======================")
    for i in range(audio_embeddings.shape[0]):
        item_len = seq_lens[i].item()  # Get the non-padded length for item i
        start_idx_content = T_audio - item_len  # Start index of the content for item i
        end_idx_target = start_idx_content + T_merged  # End index of the target slice within the content
        # Add the text_input_embeds to the calculated slice
        if end_idx_target > T_audio:
            # If the text input is longer than the audio input, we need to pad the audio input
            cut_off_len = T_audio - start_idx_content
            audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
        else:
            audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
else:
    raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
    text_input_embeds = torch.cat(
        [
            text_last_hidden,
            text_embed,
        ],
        dim=-1,
    )  # shape: (T, D1 + D2)
    text_input_embeds = self.speech_token_projector(text_input_embeds)  # shape: (T, D_codec)
    text_input_embeds_list.append(text_input_embeds)

for i in range(audio_embeddings.shape[0]):
    text_input_embeds = text_input_embeds_list[i]
    if self.codec_lm_padding_side == "right":
        audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
    elif self.codec_lm_padding_side == "left":
        start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
        start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
        assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
        if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
            text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
            logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
        audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds

speech_outputs = self.codec_lm(
    attention_mask=audio_attention_mask,
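The list comprehensions above build the codec LM's training sequences: the inputs start with mask tokens covering the question prefix plus a delay, followed by BOS and the speech codec ids, while the labels carry the codec ids shifted one step earlier and terminated by EOS. A standalone sketch of the right-padding case (special ids follow the 4096 + 4 vocabulary; codec ids and lengths are made up, and delay_step is just a parameter here since the diff shows both 1 and 2):

# Standalone sketch of the right-padded codec sequences built above.
# Special ids follow codec_vocab_size = 4096 + 4; codec ids are dummies.
PAD, EOS, BOS, MASK = 4099, 4098, 4097, 4096

def build_right_padded(speech_codec_ids, question_lens, delay_step=1):
    lens = [len(x) + question_lens[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
    max_len = max(lens)
    audio_codes = [
        [MASK] * (question_lens[i] + delay_step) + [BOS] + x + [PAD] * (max_len - lens[i])
        for i, x in enumerate(speech_codec_ids)
    ]
    audio_labels = [
        [PAD] * (question_lens[i] + delay_step) + x + [EOS] + [PAD] * (max_len - lens[i])
        for i, x in enumerate(speech_codec_ids)
    ]
    return audio_codes, audio_labels

codes, labels = build_right_padded([[11, 12, 13], [21, 22]], question_lens=[2, 3])
# codes[0]  -> [MASK, MASK, MASK, BOS, 11, 12, 13]
# labels[0] -> [PAD, PAD, PAD, 11, 12, 13, EOS]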
@@ -545,26 +527,56 @@ class SPEECH_LLM(nn.Module):

    output_hidden_states=True,
    **final_llm_kwargs
)

delay_step = 1
generated_text_ids = text_outputs.sequences  # [B, S_full]
thinker_token_embeds = [
eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device))  # (1, D)
assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
    token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
]
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
first_thinker_token_embed = torch.cat(
    [
        thinker_token_embeds_org[0][:, 1:],
        thinker_token_embeds_org[1],
    ],
    dim=1,
)

thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
thinker_hidden_states = [
    token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
]
thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_reply_part = [torch.cat(
    [
        thinker_hidden_state,
        thinker_token_embed,
    ],
    dim=-1,
)
    for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
thinker_prompt_part = torch.cat(
    [
        thinker_hidden_states[0],
        thinker_token_embeds[0],
    ],
    dim=-1,
)

thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)  # [B, S_full, D_codec]
thinker_reply_part = self.speech_token_projector(thinker_reply_part)  # [B, S_full, D_codec]

delay_step = 2

thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
    (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
    (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
)
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)  # [B, S_full, D_codec]
thinker_input_embeds = torch.cat(
    [
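At inference time the talker prompt mirrors the training layout: a run of mask tokens spanning the thinker prompt plus the delay, with a single BOS in the final slot. A minimal sketch of that construction (ids follow the 4096 + 4 vocabulary; batch size and lengths are illustrative):

# Sketch of the talker prompt built for inference: mask tokens over the thinker
# prompt plus the delay, and BOS only in the last slot. Ids follow 4096 + 4.
import torch

MASK_ID, BOS_ID = 4096, 4097
batch_size, prompt_len, delay_step = 1, 8, 2

talker_input_ids = torch.full(
    (batch_size, prompt_len + delay_step + 1), MASK_ID, dtype=torch.long
)
talker_input_ids[:, -1] = BOS_ID  # only the final position carries BOS
# talker_input_ids[0] -> [4096] * 10 + [4097]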
@@ -614,7 +626,7 @@ class SPEECH_LLM(nn.Module):

# Get logits for the *last* token generated in this step
next_token_logits = self.codec_lm_head(last_token_hidden_state)  # Use -1 index
# suppress tokens between 4096:len(vocab)-3
next_token_logits[:, 4096:-3] = -float("Inf")  # TODO: where should we suppress tokens?
# next_token_logits[:, 4096:-3] = -float("Inf")  # TODO: where should we suppress tokens?
next_token_ids = topk_sampling(
    next_token_logits,
)
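The suppressed slice [:, 4096:-3] blanks every logit from the first special id up to, but not including, BOS, so only real codec tokens plus BOS/EOS/PAD remain sampleable. With the old vocabulary of 8192 that slice covers all unused ids; with 4096 + 4 it covers only the mask token, which appears to be what the later TODO about vocab size and suppress_tokens refers to. A small check of the slice arithmetic:

# Check of the suppression slice [:, 4096:-3] for both vocab sizes in this diff.
import torch

def suppressed_ids(vocab_size):
    logits = torch.zeros(1, vocab_size)
    logits[:, 4096:-3] = -float("inf")
    return (logits[0] == -float("inf")).nonzero().flatten()

print(suppressed_ids(8192).numel())   # 4093 ids (4096 .. 8188) are masked out
print(suppressed_ids(4096 + 4))       # tensor([4096]): only the mask token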
@@ -745,7 +745,7 @@ def run(rank, world_size, args):

# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 8192
codec_vocab_size = 4096 + 4
# TODO: modify the vocab size above or suppress_tokens when decoding
config = Qwen2Config(
    vocab_size=codec_vocab_size,
@@ -770,24 +770,25 @@ def run(rank, world_size, args):

codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
if params.use_lora:
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "gate_proj",
            "down_proj",
        ],
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )
    codec_lm = get_peft_model(codec_lm, lora_config)
    codec_lm.print_trainable_parameters()
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
    codec_lm = None