Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-11 19:12:30 +00:00
debug

This commit is contained in:
parent 478d56efd8
commit 2e9be46703
@@ -59,7 +59,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     --epoch 997 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method small_test_speech2speech \
+    --method small_test_speech2speech_rerun \
     --enable-speech-output True \
     --use-lora True # --on-the-fly-feats True
 
@@ -314,10 +314,13 @@ def decode_one_batch(
         feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
     )
     cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-    with open("test.txt", 'w') as f:
-        for cut_id in cut_ids:
+    for cut_id in cut_ids:
+        speech_token_file_name = (
+            params.log_dir / f"{cut_id}.txt"
+        )
+        with open(speech_token_file_name, 'w') as f:
             # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
             #torchaudio.save(save_path, speech_output.cpu(), 16000)
             print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
             save_str = " ".join([str(i) for i in generated_speech_output])
             f.write(f"{cut_id}|{save_str}\n")
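The block above now writes the generated codec tokens for each cut to its own file under params.log_dir, one line per file in the form cut_id|tok tok .... A minimal sketch of reading such a file back, assuming that pipe-separated format (the helper name and example path are hypothetical):

from pathlib import Path

def load_speech_tokens(token_file: Path) -> tuple[str, list[int]]:
    # Each file holds a single line: "cut_id|tok tok tok ..."
    line = token_file.read_text().strip()
    cut_id, token_str = line.split("|", maxsplit=1)
    return cut_id, [int(t) for t in token_str.split()]

# cut_id, tokens = load_speech_tokens(Path("exp/log/some-cut-id.txt"))  # hypothetical path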
@@ -328,7 +331,6 @@ def decode_one_batch(
     )
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
     print(f"hyps: {hyps}")
-    exit(0)
     return {"beam-search": hyps}
 
 
@@ -280,6 +280,8 @@ class SPEECH_LLM(nn.Module):
         for i in range(labels.shape[0]):
             text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
             text_label_start_index_list.append(text_label_start_index)
+            # TODO1: check text_label_start_index position
+            print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])
 
         model_outputs = self.llm(
             inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
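For reference, a tiny standalone illustration of what the torch.where call in this hunk computes: the first position whose label is not IGNORE_TOKEN_ID. The -100 value and the toy tensor are assumptions.

import torch

IGNORE_TOKEN_ID = -100  # assumed ignore index for masked label positions
labels_i = torch.tensor([-100, -100, -100, 42, 43, 44])
text_label_start_index = torch.where(labels_i != IGNORE_TOKEN_ID)[0][0]
print(text_label_start_index)  # tensor(3): the first real label position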
@@ -316,7 +318,7 @@ class SPEECH_LLM(nn.Module):
 
             if self.codec_lm_padding_side == "right":
                 # Fill audio_codes (right padding)
-                codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+                codes_end_idx = codes_len
                 audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
                 audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
 
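With codes_end_idx taken directly as codes_len, the right-padded layout for item i is: bos/mask tokens for the first text_label_start_index + delay_step + 1 positions, then the speech codec tokens up to codes_end_idx. A worked toy example of the index arithmetic, with all sizes assumed for illustration:

# Assumed sizes: text_label_start_index = 4, delay_step = 2, speech_codec_len = 5, codes_len = 12
# bos/mask prefix: positions 0..6  (length 4 + 2 + 1 = 7)
# speech codec:    positions 7..11 (5 tokens), ending at codes_end_idx = codes_len = 12
# With these sizes the old expression (4 + 2 + 1 + 5 = 12) and the new codes_len coincide.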
@@ -349,7 +351,7 @@ class SPEECH_LLM(nn.Module):
 
         # input_ids: seq_len T1, audio_codec seq_len T2
         text_last_hidden_outputs = model_outputs.hidden_states[-1]
-        text_input_embeds = inputs_embeds + text_last_hidden_outputs
+        text_input_embeds = inputs_embeds + text_last_hidden_outputs  # TODO: this calculation looks wrong; should it use the output tokens' embeddings?
         text_input_embeds = self.speech_token_projector(text_input_embeds)
 
         T_merged = text_input_embeds.shape[1]
@@ -362,6 +364,7 @@ class SPEECH_LLM(nn.Module):
             # Need to add to the shifted position for left padding
             # Calculate the length of the non-padded sequence for each item
             seq_lens = audio_attention_mask.sum(dim=1) # Shape (B)
+            print(seq_lens[0], audio_codes[0], "======================")
             for i in range(audio_embeddings.shape[0]):
                 item_len = seq_lens[i].item() # Get the non-padded length for item i
                 start_idx_content = T_audio - item_len # Start index of the content for item i
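A minimal sketch of the left-padding arithmetic used above: with left padding, the real content of item i occupies the last seq_lens[i] positions, so it starts at T_audio - seq_lens[i]. The toy mask below is an assumption for illustration.

import torch

audio_attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 1],   # item 0: 4 real positions
    [0, 0, 0, 0, 1, 1],   # item 1: 2 real positions
])
T_audio = audio_attention_mask.shape[1]
seq_lens = audio_attention_mask.sum(dim=1)   # tensor([4, 2])
start_idx_content = T_audio - seq_lens       # tensor([2, 4]): where each item's content begins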
@@ -560,18 +563,18 @@ class SPEECH_LLM(nn.Module):
         delay_step = 2
         thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
         talker_input_ids = torch.full(
-            (batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
+            (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
         )
         talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
         thinker_input_embeds = torch.cat(
             [
                 thinker_prompt_part,
-                thinker_reply_part[:, :delay_step, :],
+                thinker_reply_part[:, :delay_step + 1, :],
             ],
             dim=1,
         )
         talker_inputs_embeds += thinker_input_embeds
-        thinker_reply_part = thinker_reply_part[:, delay_step:, :] # [B, S_full, D_codec]
+        thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec]
 
 
         past_key_values = None
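The delay_step + 1 change above lengthens the talker prefix by one frame: the talker starts from thinker_prompt_part_seq_len + delay_step + 1 BOS embeddings, the first delay_step + 1 thinker reply frames are added into that prefix, and the rest are kept for the token-by-token loop. A small runnable sketch of the shape bookkeeping, with toy sizes assumed:

import torch

B, P, R, D = 1, 3, 6, 4          # assumed toy sizes: batch, prompt len, reply len, codec dim
delay_step = 2
thinker_prompt_part = torch.zeros(B, P, D)
thinker_reply_part = torch.zeros(B, R, D)

prefix_len = P + delay_step + 1  # length of the all-BOS talker prefix
thinker_input_embeds = torch.cat(
    [thinker_prompt_part, thinker_reply_part[:, :delay_step + 1, :]], dim=1
)
assert thinker_input_embeds.shape[1] == prefix_len
remaining_reply = thinker_reply_part[:, delay_step + 1:, :]
assert remaining_reply.shape[1] == R - delay_step - 1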
@@ -583,7 +586,7 @@ class SPEECH_LLM(nn.Module):
         for t in range(max_speech_new_tokens):
             # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
             # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
-            if next_token_ids is not None:
+            if t > 0:
                 talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
                 if thinker_reply_part.shape[1] > 0:
                     talker_inputs_embeds += thinker_reply_part[:, :1, :]
@@ -607,11 +610,11 @@ class SPEECH_LLM(nn.Module):
                 output_hidden_states=True,
                 # No attention mask needed here when using past_key_values and single token input
             )
-            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec]
+            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] # TODO: check shape here
             # Get logits for the *last* token generated in this step
             next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
             # suppress tokens between 4096:len(vocab)-3
-            next_token_logits[:, 4096:-3] = -float("Inf")
+            next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where should we suppress tokens?
             next_token_ids = topk_sampling(
                 next_token_logits,
             )
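The loop above suppresses codec logits in the range 4096:-3 before calling the recipe's topk_sampling helper. A generic hedged sketch of the same idea; the helper below is illustrative rather than the recipe's implementation, and the vocabulary size is assumed:

import torch

def topk_sample(logits: torch.Tensor, top_k: int = 50, temperature: float = 1.0) -> torch.Tensor:
    # Keep the k largest logits per row, renormalize, and sample one id per row.
    logits = logits / temperature
    values, indices = torch.topk(logits, top_k, dim=-1)
    probs = torch.softmax(values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)   # [B, 1] positions within the top-k
    return indices.gather(-1, choice)                  # [B, 1] vocabulary ids

logits = torch.randn(2, 4103)          # assumed codec vocab size
logits[:, 4096:-3] = -float("Inf")     # suppress the range named in the diff
next_ids = topk_sample(logits)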
@@ -355,7 +355,7 @@ def compute_loss(
     for i in range(mask_indices[0].size(0)):
         row = mask_indices[0][i]
         col = mask_indices[1][i]
-        # + 2 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
+        # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
         target_ids[row, : col + 6] = IGNORE_TOKEN_ID
 
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
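The comment now matches the code: six header token ids (those listed in the comment) follow the matched position, so target_ids is masked through col + 6. A toy illustration of the masking, with the ignore index and tensor contents assumed:

import torch

IGNORE_TOKEN_ID = -100                                   # assumed ignore index
target_ids = torch.arange(12).unsqueeze(0)               # one toy row of targets
mask_indices = (torch.tensor([0]), torch.tensor([3]))    # pretend the match is at column 3

row, col = mask_indices[0][0], mask_indices[1][0]
target_ids[row, : col + 6] = IGNORE_TOKEN_ID             # positions 0..8 are ignored
print(target_ids)                                        # [[-100, ..., -100, 9, 10, 11]]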