Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-10 22:45:27 +00:00)

add flash attn support

This commit is contained in:
parent b305cdacc0
commit 7db40052d6
@@ -287,19 +287,50 @@ class SPEECH_LLM(nn.Module):
             device=input_ids.device
         )
         audio_labels = audio_codes.clone()
+        total_len = audio_codes.shape[1]
         for i, speech_codec in enumerate(speech_codec_ids):
             text_label_start_index = text_label_start_index_list[i]
             speech_codec = torch.tensor(
                 speech_codec, dtype=torch.int64, device=input_ids.device
             )
-            audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
-            audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
-            audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
-            audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
+            speech_codec_len = len(speech_codec)
+
+            # Calculate lengths of non-padding content
+            codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
+            # Actual label content length (speech codec tokens + eos token)
+            labels_actual_content_len = speech_codec_len + 1
+
+            if self.codec_lm_padding_side == "right":
+                # Fill audio_codes (right padding)
+                codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+                audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
+                audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
+
+                # Fill audio_labels (right padding)
+                labels_start_idx = text_label_start_index + delay_step
+                labels_speech_end_idx = labels_start_idx + speech_codec_len
+                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
+                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+            elif self.codec_lm_padding_side == "left":
+                # Calculate start indices for left padding (shifting content to the right)
+                codes_start_idx = total_len - codes_len
+                labels_start_idx = total_len - labels_actual_content_len  # Start index of the actual label content
+
+                # Fill audio_codes (left padding)
+                codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
+                audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id  # mask token_id
+                audio_codes[i, codes_speech_start_idx : total_len] = speech_codec
+
+                # Fill audio_labels (left padding)
+                labels_speech_end_idx = labels_start_idx + speech_codec_len
+                # Note: the beginning part remains pad_token_id
+                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
+                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+            else:
+                raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
-        if self.codec_lm_padding_side == "left":
-            pass
         audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
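The index arithmetic in the hunk above is easier to verify on a toy example. The sketch below is a standalone illustration, not the repo's code: pad_id, bos_id, eos_id, the lengths, and the codec tokens are all made-up stand-ins for the model-config values.

import torch

# Hypothetical stand-ins for self.codec_lm.config values.
pad_id, bos_id, eos_id = 0, 1, 2
total_len = 12            # audio_codes.shape[1]
text_start, delay_step = 3, 1
codec = torch.tensor([7, 8, 9])
L = len(codec)

codes = torch.full((total_len,), pad_id)
labels = torch.full((total_len,), pad_id)

padding_side = "left"     # try "right" as well
if padding_side == "right":
    codes[: text_start + delay_step + 1] = bos_id
    codes[text_start + delay_step + 1 : text_start + delay_step + 1 + L] = codec
    labels[text_start + delay_step : text_start + delay_step + L] = codec
    labels[text_start + delay_step + L] = eos_id
else:
    # Shift everything right so the content ends exactly at total_len.
    codes_len = text_start + delay_step + 1 + L
    start = total_len - codes_len
    codes[start : start + text_start + delay_step + 1] = bos_id
    codes[start + text_start + delay_step + 1 :] = codec
    labels_start = total_len - (L + 1)   # codec tokens + eos
    labels[labels_start : labels_start + L] = codec
    labels[labels_start + L] = eos_id

print(codes.tolist())   # [0, 0, 0, 0, 1, 1, 1, 1, 1, 7, 8, 9]
print(labels.tolist())  # [0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 9, 2]

Either way, the labels sit one step ahead of the codes, so position t of the codes predicts the codec token stored at position t of the labels.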
@@ -308,7 +339,24 @@ class SPEECH_LLM(nn.Module):
         text_input_embeds = inputs_embeds + text_last_hidden_outputs
         text_input_embeds = self.speech_token_projector(text_input_embeds)
-        audio_embeddings[:, : text_input_embeds.shape[1]] += text_input_embeds
+        T_merged = text_input_embeds.shape[1]
+        T_audio = audio_embeddings.shape[1]
+
+        if self.codec_lm_padding_side == "right":
+            # Add to the beginning for right padding
+            audio_embeddings[:, :T_merged] += text_input_embeds
+        elif self.codec_lm_padding_side == "left":
+            # Need to add at the shifted positions for left padding
+            # Calculate the length of the non-padded sequence for each item
+            seq_lens = audio_attention_mask.sum(dim=1)  # Shape (B)
+            for i in range(audio_embeddings.shape[0]):
+                item_len = seq_lens[i].item()  # Get the non-padded length for item i
+                start_idx_content = T_audio - item_len  # Start index of the content for item i
+                end_idx_target = start_idx_content + T_merged  # End index of the target slice within the content
+                # Add the text_input_embeds to the calculated slice
+                audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+        else:
+            raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
         speech_outputs = self.codec_lm(
             attention_mask=audio_attention_mask,
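The offset math in this hunk can also be checked at toy scale: with left padding, each item's merged text embeddings must be added starting at its first non-padded position rather than at index 0. A minimal sketch with made-up shapes and names:

import torch

B, T_audio, T_merged, D = 2, 6, 3, 4
audio_embeddings = torch.zeros(B, T_audio, D)
text_input_embeds = torch.ones(B, T_merged, D)
# Left-padded attention mask: item 0 has 5 real tokens, item 1 has 4.
audio_attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1],
                                     [0, 0, 1, 1, 1, 1]]).bool()

seq_lens = audio_attention_mask.sum(dim=1)   # (B,)
for i in range(B):
    start = T_audio - seq_lens[i].item()     # first non-padded position
    audio_embeddings[i, start : start + T_merged] += text_input_embeds[i]

# Item 0 is updated at positions 1..3, item 1 at positions 2..4.
print(audio_embeddings[..., 0])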
@@ -383,7 +383,7 @@ def compute_loss(
     last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
     history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
     # e.g. USER: Write a poem about summer. ASSISTANT: The summer sun blazes, everything grows, the sunshine is bright, and we enjoy the lovely summer days. USER: List some news headlines for me. ASSISTANT: The news in today's society never stops. <USER>: Tell me how to cook chicken
-    # <USER>: Critique the following sentence: He is kind-hearted. The output should be “He is a kind-hearted person.
+    # <USER>: Critique the following sentence: He is kind-hearted. The output should be "He is a kind-hearted person.

     messages = []
     for i, total_round in enumerate(chat_rounds):
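A quick illustration of the two comprehensions above; the sample string mirrors the translated comment and is illustrative, not taken from the dataset:

q = ("USER: Write a poem about summer. ASSISTANT: ... "
     "USER: List some news headlines. ASSISTANT: ... "
     "<USER>: Tell me how to cook chicken")
last_question = q.split('<USER>: ')[-1].strip()      # "Tell me how to cook chicken"
history_context = q.rsplit('<USER>:', 1)[0].strip()  # everything before the final turn
print(last_question)
print(history_context)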
@@ -730,11 +730,14 @@ def run(rank, world_size, args):

     if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
         if params.use_flash_attn:
             attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16  # or torch.bfloat16 if needed/supported
         else:
             attn_implementation = "eager"
             torch_dtype = torch.float16

         # codec_lm = AutoModelForCausalLM.from_pretrained(
         #     params.llm_path_or_name,
         #     attn_implementation=attn_implementation,
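A hedged variant of this selection, not part of the commit: some setups also verify that the flash-attn package is importable before requesting flash_attention_2, and prefer bfloat16 where the hardware supports it. The use_flash_attn flag below is a local stand-in for params.use_flash_attn.

import importlib.util
import torch

use_flash_attn = True  # stand-in for params.use_flash_attn
flash_attn_available = importlib.util.find_spec("flash_attn") is not None
if use_flash_attn and flash_attn_available:
    attn_implementation = "flash_attention_2"
    # bfloat16 avoids fp16 overflow issues where the GPU supports it.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16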
@@ -750,7 +753,14 @@ def run(rank, world_size, args):
             intermediate_size=2048,
             max_position_embeddings=4096,
         )
-        codec_lm = Qwen2ForCausalLM(config=config)
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor
+        # Use AutoModelForCausalLM.from_config for more generality
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+        )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
         codec_lm.vocab_size = codec_vocab_size
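For reference, a minimal self-contained sketch of the from_config path; the config values below are toy numbers and the vocab size is hypothetical, and flash_attention_2 additionally requires the flash-attn package plus a half-precision dtype at construction time.

import torch
from transformers import AutoModelForCausalLM, Qwen2Config

# Toy-sized config; the recipe itself uses larger values.
config = Qwen2Config(
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    num_key_value_heads=16,
    intermediate_size=2048,
    max_position_embeddings=4096,
)
codec_lm = AutoModelForCausalLM.from_config(
    config=config,
    attn_implementation="eager",  # or "flash_attention_2"
    torch_dtype=torch.float16,
)
codec_lm.resize_token_embeddings(8192)  # hypothetical codec vocab size
print(codec_lm.get_input_embeddings().weight.shape)  # torch.Size([8192, 1024])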
@@ -829,7 +839,7 @@ def run(rank, world_size, args):
             return False
         # cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
         codec_len = len(c.custom["answer_cosyvoice_speech_token"])
-        if codec_len > 2048:
+        if codec_len > 2200:
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
             )
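The check above has the usual shape of a Lhotse cut-filter predicate. A sketch of how such a check is typically packaged and applied (function and variable names here are illustrative, not from the recipe):

import logging

def keep_cut(c) -> bool:
    # Drop cuts whose codec token sequence is too long to train on.
    codec_len = len(c.custom["answer_cosyvoice_speech_token"])
    if codec_len > 2200:
        logging.warning(f"Exclude cut with ID {c.id}: codec length {codec_len}")
        return False
    return True

# train_cuts = train_cuts.filter(keep_cut)  # lhotse's CutSet.filter is lazy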