root 2025-04-24 08:24:11 +00:00
parent 478d56efd8
commit 2e9be46703
4 changed files with 20 additions and 15 deletions

File 1 of 4

@@ -59,7 +59,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 --epoch 997 --avg 1 \
 --manifest-dir data/fbank \
 --use-flash-attn True \
---method small_test_speech2speech \
+--method small_test_speech2speech_rerun \
 --enable-speech-output True \
 --use-lora True # --on-the-fly-feats True

File 2 of 4

@@ -314,8 +314,11 @@ def decode_one_batch(
 feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
 )
 cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-with open("test.txt", 'w') as f:
 for cut_id in cut_ids:
+speech_token_file_name = (
+params.log_dir / f"{cut_id}.txt"
+)
+with open(speech_token_file_name, 'w') as f:
 # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
 #torchaudio.save(save_path, speech_output.cpu(), 16000)
 print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
@@ -328,7 +331,6 @@ def decode_one_batch(
 )
 hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
 print(f"hyps: {hyps}")
-exit(0)
 return {"beam-search": hyps}

File 3 of 4

@@ -280,6 +280,8 @@ class SPEECH_LLM(nn.Module):
 for i in range(labels.shape[0]):
 text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
 text_label_start_index_list.append(text_label_start_index)
+# TODO1: check text_label_start_index position
+print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])
 model_outputs = self.llm(
 inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
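For context on text_label_start_index: torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0] returns the first position whose label is not the ignore id, i.e. where the supervised answer tokens begin. A small self-contained illustration, assuming IGNORE_TOKEN_ID is -100 as in common SFT recipes (the token ids are invented):

import torch

IGNORE_TOKEN_ID = -100  # assumed value; the recipe defines its own constant

# Prompt positions carry IGNORE_TOKEN_ID; answer positions keep real ids (invented here).
labels = torch.tensor([-100, -100, -100, 151644, 9707, 0])
text_label_start_index = torch.where(labels != IGNORE_TOKEN_ID)[0][0]
print(text_label_start_index.item())  # -> 3, the first supervised (answer) position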
@@ -316,7 +318,7 @@ class SPEECH_LLM(nn.Module):
 if self.codec_lm_padding_side == "right":
 # Fill audio_codes (right padding)
-codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+codes_end_idx = codes_len
 audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
 audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
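In the right-padding branch, the first text_label_start_index + delay_step + 1 codec positions are filled with the codec BOS id as a mask, and the actual codec tokens follow up to codes_end_idx (now taken from codes_len). A toy illustration of that layout; all sizes and ids below are invented:

import torch

bos_token_id = 4096          # assumed codec BOS / mask id
text_label_start_index = 3   # assumed first answer position for this item
delay_step = 2
speech_codec = torch.tensor([11, 22, 33, 44])  # invented codec tokens for this item

codes_len = text_label_start_index + delay_step + 1 + speech_codec.shape[0]
audio_codes = torch.zeros(codes_len + 2, dtype=torch.long)  # +2 trailing pad positions

# BOS/mask tokens cover the text prefix plus the delay, then the codec tokens follow.
audio_codes[: text_label_start_index + delay_step + 1] = bos_token_id
audio_codes[text_label_start_index + delay_step + 1 : codes_len] = speech_codec
print(audio_codes)
# tensor([4096, 4096, 4096, 4096, 4096, 4096,   11,   22,   33,   44,    0,    0])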
@@ -349,7 +351,7 @@ class SPEECH_LLM(nn.Module):
 # input_ids: seq_len T1, audio_codec seq_len T2
 text_last_hidden_outputs = model_outputs.hidden_states[-1]
-text_input_embeds = inputs_embeds + text_last_hidden_outputs
+text_input_embeds = inputs_embeds + text_last_hidden_outputs # TODO: is this computed incorrectly for the output tokens' embedding?
 text_input_embeds = self.speech_token_projector(text_input_embeds)
 T_merged = text_input_embeds.shape[1]
@@ -362,6 +364,7 @@ class SPEECH_LLM(nn.Module):
 # Need to add to the shifted position for left padding
 # Calculate the length of the non-padded sequence for each item
 seq_lens = audio_attention_mask.sum(dim=1) # Shape (B)
+print(seq_lens[0], audio_codes[0], "======================")
 for i in range(audio_embeddings.shape[0]):
 item_len = seq_lens[i].item() # Get the non-padded length for item i
 start_idx_content = T_audio - item_len # Start index of the content for item i
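With left padding, the real content of row i starts at start_idx_content = T_audio - item_len, where item_len is that row's attention-mask sum. A tiny example of the index arithmetic (mask values invented):

import torch

T_audio = 6
# Row 0 has 4 real positions, row 1 has 2; padding sits on the left.
audio_attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 1],
    [0, 0, 0, 0, 1, 1],
])
seq_lens = audio_attention_mask.sum(dim=1)  # tensor([4, 2])
for i in range(audio_attention_mask.shape[0]):
    item_len = seq_lens[i].item()
    start_idx_content = T_audio - item_len  # 2 for row 0, 4 for row 1
    print(i, start_idx_content)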
@@ -560,18 +563,18 @@ class SPEECH_LLM(nn.Module):
 delay_step = 2
 thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
 talker_input_ids = torch.full(
-(batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
+(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
 )
 talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
 thinker_input_embeds = torch.cat(
 [
 thinker_prompt_part,
-thinker_reply_part[:, :delay_step, :],
+thinker_reply_part[:, :delay_step + 1, :],
 ],
 dim=1,
 )
 talker_inputs_embeds += thinker_input_embeds
-thinker_reply_part = thinker_reply_part[:, delay_step:, :] # [B, S_full, D_codec]
+thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec]
 past_key_values = None
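This hunk widens the talker's BOS prefix and the first thinker_reply_part slice from delay_step to delay_step + 1 frames, mirroring the training-side offset text_label_start_index + delay_step + 1. A shape-only sketch of the split; all dimensions are invented:

import torch

batch_size, d_codec = 2, 8
delay_step = 2
thinker_prompt_part = torch.randn(batch_size, 5, d_codec)   # invented prompt length
thinker_reply_part = torch.randn(batch_size, 10, d_codec)   # invented reply length

# The first delay_step + 1 reply frames are consumed together with the prompt prefix ...
prefix = torch.cat([thinker_prompt_part, thinker_reply_part[:, :delay_step + 1, :]], dim=1)
# ... and the remainder is fed one frame per generation step afterwards.
remainder = thinker_reply_part[:, delay_step + 1:, :]
print(prefix.shape, remainder.shape)  # torch.Size([2, 8, 8]) torch.Size([2, 7, 8])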
@@ -583,7 +586,7 @@ class SPEECH_LLM(nn.Module):
 for t in range(max_speech_new_tokens):
 # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
 # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
-if next_token_ids is not None:
+if t > 0:
 talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
 if thinker_reply_part.shape[1] > 0:
 talker_inputs_embeds += thinker_reply_part[:, :1, :]
@@ -607,11 +610,11 @@ class SPEECH_LLM(nn.Module):
 output_hidden_states=True,
 # No attention mask needed here when using past_key_values and single token input
 )
-last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec]
+last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] # TODO: check shape here
 # Get logits for the *last* token generated in this step
 next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
 # suppress tokens between 4096:len(vocab)-3
-next_token_logits[:, 4096:-3] = -float("Inf")
+next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where should we suppress tokens?
 next_token_ids = topk_sampling(
 next_token_logits,
 )
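Setting next_token_logits[:, 4096:-3] to -inf bans that id range outright, so sampling can only pick codec ids below 4096 or the last three specials. A minimal sketch with a plain top-k sampler standing in for the repo's topk_sampling helper (vocabulary size and k are invented assumptions):

import torch

def topk_sampling(logits: torch.Tensor, k: int = 10) -> torch.Tensor:
    # Keep only the k largest logits, renormalize, and sample one id per row.
    topk_vals, topk_idx = logits.topk(k, dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx.gather(-1, choice)  # [B, 1]

vocab_size = 4100                      # invented: 4096 codec ids plus a few specials
next_token_logits = torch.randn(2, vocab_size)
next_token_logits[:, 4096:-3] = -float("Inf")  # forbid ids 4096 .. vocab_size-4
next_token_ids = topk_sampling(next_token_logits)
print(next_token_ids.shape)  # torch.Size([2, 1])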

File 4 of 4

@@ -355,7 +355,7 @@ def compute_loss(
 for i in range(mask_indices[0].size(0)):
 row = mask_indices[0][i]
 col = mask_indices[1][i]
-# + 2 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
+# + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
 target_ids[row, : col + 6] = IGNORE_TOKEN_ID
 attention_mask = input_ids.ne(tokenizer.pad_token_id)
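The masking loop sets target_ids to IGNORE_TOKEN_ID for everything up to and including the six chat-template tokens that open the assistant turn (151665, 151645, 198, 151644, 77091, 198 per the updated comment), so the loss only covers the reply. A toy version of that masking; the anchor id, token ids, and IGNORE_TOKEN_ID value below are assumptions, not the repo's exact setup:

import torch

IGNORE_TOKEN_ID = -100   # assumed value
anchor_id = 151665       # assumed anchor; the real mask_indices may key on a different template id

# Toy sequence: two user tokens, the six template ids from the comment, then the reply (7, 8, 9).
input_ids = torch.tensor([[11, 12, 151665, 151645, 198, 151644, 77091, 198, 7, 8, 9]])
target_ids = input_ids.clone()

mask_indices = torch.where(input_ids == anchor_id)
for i in range(mask_indices[0].size(0)):
    row = mask_indices[0][i]
    col = mask_indices[1][i]
    # + 6 skips the template ids so only the reply tokens keep real labels.
    target_ids[row, : col + 6] = IGNORE_TOKEN_ID
print(target_ids)
# tensor([[-100, -100, -100, -100, -100, -100, -100, -100,    7,    8,    9]])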