Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00
debug
This commit is contained in: parent 478d56efd8, commit 2e9be46703
@@ -59,7 +59,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     --epoch 997 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method small_test_speech2speech \
+    --method small_test_speech2speech_rerun \
     --enable-speech-output True \
     --use-lora True # --on-the-fly-feats True
@@ -314,10 +314,13 @@ def decode_one_batch(
         feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
     )
     cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-    with open("test.txt", 'w') as f:
-        for cut_id in cut_ids:
-            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-            #torchaudio.save(save_path, speech_output.cpu(), 16000)
+    for cut_id in cut_ids:
+        speech_token_file_name = (
+            params.log_dir / f"{cut_id}.txt"
+        )
+        with open(speech_token_file_name, 'w') as f:
+            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+            #torchaudio.save(save_path, speech_output.cpu(), 16000)
             print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
             save_str = " ".join([str(i) for i in generated_speech_output])
             f.write(f"{cut_id}|{save_str}\n")
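For illustration, a minimal standalone sketch of how the per-cut token files written above could be read back; the helper name is hypothetical, and it assumes the tokens are integer codec ids, matching the "{cut_id}|tok tok ..." format used in the f.write above:

from pathlib import Path


def load_speech_tokens(log_dir: str) -> dict:
    # Parse files of the form "<cut_id>|<tok> <tok> ...", one file per cut,
    # as produced by the decoding loop above.
    tokens = {}
    for txt in Path(log_dir).glob("*.txt"):
        line = txt.read_text().strip()
        cut_id, tok_str = line.split("|", maxsplit=1)
        tokens[cut_id] = [int(t) for t in tok_str.split()]
    return tokens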
@@ -328,7 +331,6 @@ def decode_one_batch(
     )
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
     print(f"hyps: {hyps}")
     exit(0)
     return {"beam-search": hyps}
@@ -280,6 +280,8 @@ class SPEECH_LLM(nn.Module):
         for i in range(labels.shape[0]):
             text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
             text_label_start_index_list.append(text_label_start_index)
+            # TODO1: check text_label_start_index position
+            print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])

         model_outputs = self.llm(
             inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
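A tiny standalone illustration of the torch.where pattern used above to find the first supervised label position; IGNORE_TOKEN_ID = -100 and the toy label row are assumptions:

import torch

IGNORE_TOKEN_ID = -100  # assumed; the usual HF label-masking convention

labels = torch.tensor([-100, -100, -100, 15339, 1917, 13])  # toy label row
# Index of the first element that is not IGNORE_TOKEN_ID.
text_label_start_index = torch.where(labels != IGNORE_TOKEN_ID)[0][0]
print(text_label_start_index)  # tensor(3)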
@@ -316,7 +318,7 @@ class SPEECH_LLM(nn.Module):

             if self.codec_lm_padding_side == "right":
                 # Fill audio_codes (right padding)
-                codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+                codes_end_idx = codes_len
                 audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
                 audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
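A toy sketch of the right-padding fill above: the first text_label_start_index + delay_step + 1 positions get the codec BOS id (used as a mask token) and the codec tokens follow; all values and lengths here are illustrative assumptions:

import torch

bos_token_id = 0                      # assumed codec BOS / mask id
delay_step = 2
text_label_start_index = 3
speech_codec = torch.tensor([11, 12, 13, 14])   # toy codec tokens
speech_codec_len = speech_codec.shape[0]

# With no extra padding, both expressions for the end index coincide.
codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
codes_end_idx = codes_len

audio_codes = torch.full((codes_len,), -1)       # one row of the batch
audio_codes[: text_label_start_index + delay_step + 1] = bos_token_id
audio_codes[text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
print(audio_codes)  # tensor([ 0,  0,  0,  0,  0,  0, 11, 12, 13, 14])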
@@ -349,7 +351,7 @@ class SPEECH_LLM(nn.Module):

         # input_ids: seq_len T1, audio_codec seq_len T2
         text_last_hidden_outputs = model_outputs.hidden_states[-1]
-        text_input_embeds = inputs_embeds + text_last_hidden_outputs
+        text_input_embeds = inputs_embeds + text_last_hidden_outputs  # TODO: this calculation is wrong; should it use the output tokens' embeddings?
         text_input_embeds = self.speech_token_projector(text_input_embeds)

         T_merged = text_input_embeds.shape[1]
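A minimal standalone sketch of the residual-add plus projection above; the tensor sizes and the Linear projector are assumptions:

import torch
import torch.nn as nn

B, T, d_text, d_codec = 2, 5, 8, 4                       # toy sizes
inputs_embeds = torch.randn(B, T, d_text)
text_last_hidden_outputs = torch.randn(B, T, d_text)     # stand-in for model_outputs.hidden_states[-1]

speech_token_projector = nn.Linear(d_text, d_codec)      # assumed to be a Linear layer

text_input_embeds = inputs_embeds + text_last_hidden_outputs
text_input_embeds = speech_token_projector(text_input_embeds)
print(text_input_embeds.shape)  # torch.Size([2, 5, 4])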
@@ -362,6 +364,7 @@ class SPEECH_LLM(nn.Module):
             # Need to add to the shifted position for left padding
             # Calculate the length of the non-padded sequence for each item
             seq_lens = audio_attention_mask.sum(dim=1)  # Shape (B)
+            print(seq_lens[0], audio_codes[0], "======================")
             for i in range(audio_embeddings.shape[0]):
                 item_len = seq_lens[i].item()  # Get the non-padded length for item i
                 start_idx_content = T_audio - item_len  # Start index of the content for item i
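A small self-contained check of the left-padding arithmetic above: the content of row i occupies the last seq_lens[i] positions, so it starts at T_audio - item_len; the mask values are made up:

import torch

audio_attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                     [0, 1, 1, 1, 1]])   # 1 = real token, left padding
T_audio = audio_attention_mask.shape[1]

seq_lens = audio_attention_mask.sum(dim=1)               # shape (B,)
for i in range(audio_attention_mask.shape[0]):
    item_len = seq_lens[i].item()
    start_idx_content = T_audio - item_len               # first non-pad position of row i
    print(i, start_idx_content)                          # prints: 0 2, then 1 1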
@@ -560,18 +563,18 @@ class SPEECH_LLM(nn.Module):
         delay_step = 2
         thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
         talker_input_ids = torch.full(
-            (batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
+            (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
         )
         talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)  # [B, S_full, D_codec]
         thinker_input_embeds = torch.cat(
             [
                 thinker_prompt_part,
-                thinker_reply_part[:, :delay_step, :],
+                thinker_reply_part[:, :delay_step + 1, :],
             ],
             dim=1,
         )
         talker_inputs_embeds += thinker_input_embeds
-        thinker_reply_part = thinker_reply_part[:, delay_step:, :]  # [B, S_full, D_codec]
+        thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :]  # [B, S_full, D_codec]


         past_key_values = None
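The three +1 changes above all shift the talker/thinker split one frame later: the talker prompt now covers the thinker prompt plus delay_step + 1 reply frames, and the remaining reply starts at delay_step + 1. A shape-only sketch of that bookkeeping, with toy sizes as assumptions:

import torch

B, D = 2, 8
delay_step = 2
thinker_prompt_part = torch.randn(B, 5, D)      # [B, S_prompt, D]
thinker_reply_part = torch.randn(B, 10, D)      # [B, S_reply, D]

thinker_input_embeds = torch.cat(
    [thinker_prompt_part, thinker_reply_part[:, : delay_step + 1, :]], dim=1
)
remaining_reply = thinker_reply_part[:, delay_step + 1 :, :]

print(thinker_input_embeds.shape)  # torch.Size([2, 8, 8])  -> S_prompt + delay_step + 1
print(remaining_reply.shape)       # torch.Size([2, 7, 8])  -> S_reply - (delay_step + 1)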
@@ -583,7 +586,7 @@ class SPEECH_LLM(nn.Module):
         for t in range(max_speech_new_tokens):
             # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
             # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids)  # [B, 1, D_codec]
-            if next_token_ids is not None:
+            if t > 0:
                 talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids)  # [B, 1, D_codec]
                 if thinker_reply_part.shape[1] > 0:
                     talker_inputs_embeds += thinker_reply_part[:, :1, :]
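Both the old guard (next_token_ids is not None) and the new one (t > 0) skip the re-embedding on the first step, where the prompt embeddings are already in place. A toy skeleton of that step loop; the embedding table, sizes, and the random sampling stand-in are assumptions:

import torch

B, D, steps = 2, 4, 3
embed = torch.nn.Embedding(10, D)            # stand-in for codec_lm.get_input_embeddings()
talker_inputs_embeds = torch.randn(B, 6, D)  # prompt embeddings, used as-is on step 0
next_token_ids = None

for t in range(steps):
    if t > 0:
        # From step 1 on, feed only the embedding of the previously sampled token.
        talker_inputs_embeds = embed(next_token_ids)   # [B, 1, D]
    # ... run the codec LM on talker_inputs_embeds with past_key_values here ...
    next_token_ids = torch.randint(0, 10, (B, 1))      # stand-in for sampling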
@@ -607,11 +610,11 @@ class SPEECH_LLM(nn.Module):
                 output_hidden_states=True,
                 # No attention mask needed here when using past_key_values and single token input
             )
-            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]  # [B, D_codec]
+            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]  # [B, D_codec]  # TODO: check shape here
             # Get logits for the *last* token generated in this step
             next_token_logits = self.codec_lm_head(last_token_hidden_state)  # Use -1 index
             # suppress tokens between 4096:len(vocab)-3
-            next_token_logits[:, 4096:-3] = -float("Inf")
+            next_token_logits[:, 4096:-3] = -float("Inf")  # TODO: where should we suppress tokens?
             next_token_ids = topk_sampling(
                 next_token_logits,
             )
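A minimal standalone sketch of the suppression-then-sample step above, with a simple top-k sampler standing in for the repo's topk_sampling helper; the sampler, k, and the vocabulary size are assumptions:

import torch

def simple_topk_sampling(logits: torch.Tensor, k: int = 50) -> torch.Tensor:
    # Keep only the k largest logits per row, renormalize, and sample one id per row.
    topk_vals, topk_idx = logits.topk(k, dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)     # [B, 1]
    return topk_idx.gather(-1, choice)                   # [B, 1] token ids

vocab_size = 4100                                        # assumed codec vocab size
next_token_logits = torch.randn(2, vocab_size)
next_token_logits[:, 4096:-3] = -float("Inf")            # suppress this id range, as in the diff
next_token_ids = simple_topk_sampling(next_token_logits)
print(next_token_ids.shape)  # torch.Size([2, 1])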
@@ -355,7 +355,7 @@ def compute_loss(
         for i in range(mask_indices[0].size(0)):
             row = mask_indices[0][i]
             col = mask_indices[1][i]
-            # + 2 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
+            # + 6 to skip: 'assistant', '\n' 151665, 151645, 198, 151644, 77091, 198
             target_ids[row, : col + 6] = IGNORE_TOKEN_ID

         attention_mask = input_ids.ne(tokenizer.pad_token_id)
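A tiny sketch of the masking above: the updated comment counts six header token ids (151665, 151645, 198, 151644, 77091, 198), so everything up to col + 6 is set to IGNORE_TOKEN_ID; the toy tensors and IGNORE_TOKEN_ID = -100 are assumptions:

import torch

IGNORE_TOKEN_ID = -100                                   # assumed label-ignore value
target_ids = torch.arange(12).reshape(1, 12)             # toy targets for a single row
mask_indices = (torch.tensor([0]), torch.tensor([3]))    # pretend the header starts at column 3

for i in range(mask_indices[0].size(0)):
    row = mask_indices[0][i]
    col = mask_indices[1][i]
    # + 6 to skip the assistant-header tokens before the real answer.
    target_ids[row, : col + 6] = IGNORE_TOKEN_ID

print(target_ids)
# tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100,    9,   10,   11]])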