commit 2e9be46703
parent 478d56efd8
Author: root
Date: 2025-04-24 08:24:11 +00:00
4 changed files with 20 additions and 15 deletions

Changed file (name not shown)

@@ -59,7 +59,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
--epoch 997 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
- --method small_test_speech2speech \
+ --method small_test_speech2speech_rerun \
--enable-speech-output True \
--use-lora True # --on-the-fly-feats True

Changed file (name not shown)

@@ -314,10 +314,13 @@ def decode_one_batch(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
- with open("test.txt", 'w') as f:
-     for cut_id in cut_ids:
-         # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-         #torchaudio.save(save_path, speech_output.cpu(), 16000)
+ for cut_id in cut_ids:
+     speech_token_file_name = (
+         params.log_dir / f"{cut_id}.txt"
+     )
+     with open(speech_token_file_name, 'w') as f:
+         # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+         #torchaudio.save(save_path, speech_output.cpu(), 16000)
print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
save_str = " ".join([str(i) for i in generated_speech_output])
f.write(f"{cut_id}|{save_str}\n")
@@ -328,7 +331,6 @@ def decode_one_batch(
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
print(f"hyps: {hyps}")
- exit(0)
return {"beam-search": hyps}
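For reference, each file written in the hunk above holds a single "cut_id|token token ..." line; a minimal sketch of how such files could be read back (this helper is illustrative and not part of the repo):

import torch  # not needed here, shown only to match the repo's stack
from pathlib import Path

def load_speech_tokens(log_dir: str) -> dict:
    """Illustrative reader for the per-cut token files written in decode_one_batch."""
    tokens = {}
    for txt in Path(log_dir).glob("*.txt"):
        with open(txt) as f:
            cut_id, token_str = f.read().strip().split("|", maxsplit=1)
            tokens[cut_id] = [int(t) for t in token_str.split()]
    return tokens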

Changed file (name not shown)

@@ -280,6 +280,8 @@ class SPEECH_LLM(nn.Module):
for i in range(labels.shape[0]):
text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
text_label_start_index_list.append(text_label_start_index)
+ # TODO1: check text_label_start_index position
+ print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
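As a side note on the TODO above: text_label_start_index is simply the first position whose label is not IGNORE_TOKEN_ID; a toy check (the value -100 is an assumption, matching the usual Hugging Face ignore index):

import torch

IGNORE_TOKEN_ID = -100  # assumption: the standard ignore index

labels = torch.tensor([-100, -100, -100, 872, 17, 9])
text_label_start_index = torch.where(labels != IGNORE_TOKEN_ID)[0][0]
print(text_label_start_index.item())  # 3 -> the first real text label sits at index 3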
@@ -316,7 +318,7 @@ class SPEECH_LLM(nn.Module):
if self.codec_lm_padding_side == "right":
# Fill audio_codes (right padding)
- codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+ codes_end_idx = codes_len
audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
@@ -349,7 +351,7 @@ class SPEECH_LLM(nn.Module):
# input_ids: seq_len T1, audio_codec seq_len T2
text_last_hidden_outputs = model_outputs.hidden_states[-1]
- text_input_embeds = inputs_embeds + text_last_hidden_outputs
+ text_input_embeds = inputs_embeds + text_last_hidden_outputs # TODO: is the output tokens' embedding computed incorrectly here?
text_input_embeds = self.speech_token_projector(text_input_embeds)
T_merged = text_input_embeds.shape[1]
@@ -362,6 +364,7 @@ class SPEECH_LLM(nn.Module):
# Need to add to the shifted position for left padding
# Calculate the length of the non-padded sequence for each item
seq_lens = audio_attention_mask.sum(dim=1) # Shape (B)
+ print(seq_lens[0], audio_codes[0], "======================")
for i in range(audio_embeddings.shape[0]):
item_len = seq_lens[i].item() # Get the non-padded length for item i
start_idx_content = T_audio - item_len # Start index of the content for item i
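The left-padding arithmetic above can be sanity-checked on a toy mask (the tensors below are made up for illustration):

import torch

# left padding: zeros first, real content at the right end of each row
audio_attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                     [0, 1, 1, 1, 1]])
T_audio = audio_attention_mask.shape[1]
seq_lens = audio_attention_mask.sum(dim=1)   # tensor([3, 4]) non-padded lengths
start_idx_content = T_audio - seq_lens       # tensor([2, 1]) where each item's content starts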
@@ -560,18 +563,18 @@ class SPEECH_LLM(nn.Module):
delay_step = 2
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
- (batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
+ (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
)
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
thinker_input_embeds = torch.cat(
[
thinker_prompt_part,
- thinker_reply_part[:, :delay_step, :],
+ thinker_reply_part[:, :delay_step + 1, :],
],
dim=1,
)
talker_inputs_embeds += thinker_input_embeds
- thinker_reply_part = thinker_reply_part[:, delay_step:, :] # [B, S_full, D_codec]
+ thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec]
past_key_values = None
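The prefill length bookkeeping in the hunk above can be checked with dummy tensors (all shapes below are assumptions chosen for illustration):

import torch

B, P, R, D = 2, 5, 10, 8                  # batch, prompt length, reply length, codec dim (made up)
delay_step = 2
thinker_prompt_part = torch.randn(B, P, D)
thinker_reply_part = torch.randn(B, R, D)

talker_inputs_embeds = torch.zeros(B, P + delay_step + 1, D)   # stands in for the embedded BOS-filled talker_input_ids
thinker_input_embeds = torch.cat(
    [thinker_prompt_part, thinker_reply_part[:, :delay_step + 1, :]], dim=1
)                                                              # [B, P + delay_step + 1, D]
talker_inputs_embeds += thinker_input_embeds                   # lengths match, so the add is well defined
thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, R - delay_step - 1, D], consumed one frame per decode step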
@@ -583,7 +586,7 @@ class SPEECH_LLM(nn.Module):
for t in range(max_speech_new_tokens):
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
- if next_token_ids is not None:
+ if t > 0:
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :]
@@ -607,11 +610,11 @@ class SPEECH_LLM(nn.Module):
output_hidden_states=True,
# No attention mask needed here when using past_key_values and single token input
)
- last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec]
+ last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] #TODO: check shape here
# Get logits for the *last* token generated in this step
next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
# suppress tokens between 4096:len(vocab)-3
- next_token_logits[:, 4096:-3] = -float("Inf")
+ next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where should we suppress tokens?
next_token_ids = topk_sampling(
next_token_logits,
)
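topk_sampling itself is not shown in this diff; a minimal sketch of standard top-k sampling that would fit the call above (the top_k and temperature defaults are assumptions, and the repo's real implementation may differ):

import torch

def topk_sampling(logits: torch.Tensor, top_k: int = 50, temperature: float = 1.0) -> torch.Tensor:
    """Illustrative top-k sampling over [B, vocab] logits."""
    logits = logits / temperature
    topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)   # keep only the k largest logits
    probs = torch.softmax(topk_vals, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)          # [B, 1] index into the top-k set
    return topk_idx.gather(-1, sampled)                        # [B, 1] sampled token ids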

Changed file (name not shown)

@@ -355,7 +355,7 @@ def compute_loss(
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
- # + 2 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
+ # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
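A toy view of the masking above, assuming the token matched at col is the first of the six ids listed in the comment (the surrounding values are made up): everything up to and including the assistant header is set to IGNORE_TOKEN_ID so that only the reply tokens contribute to the loss.

import torch

IGNORE_TOKEN_ID = -100  # assumption: the standard ignore index

# one row: two prompt tokens, the six header tokens from the comment, then two made-up reply tokens
target_ids = torch.tensor([[11, 22, 151665, 151645, 198, 151644, 77091, 198, 555, 666]])
col = 2                                   # column where the matched marker (assumed to be 151665) sits
target_ids[0, : col + 6] = IGNORE_TOKEN_ID
print(target_ids)  # tensor([[-100, -100, -100, -100, -100, -100, -100, -100, 555, 666]])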