fix bugs when padding right
parent 23fdef2fd3
commit 478d56efd8
@@ -35,8 +35,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi
 
 
-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 3: Combine features"
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Combine features"
   manifest_dir=data/fbank
   if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
     pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
@@ -48,39 +48,27 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   fi
 fi
 
-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "stage 2: "
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "stage 3: "
   python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
+    --max-duration 1 \
+    --exp-dir slam_omni/exp_speech2speech_test_flash_attn \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
+    --epoch 997 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method pure_text_sampling \
+    --method small_test_speech2speech \
+    --enable-speech-output True \
     --use-lora True # --on-the-fly-feats True
 
 fi
 
-if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
-  log "stage 2: "
-  python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
-    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
-    --manifest-dir data/fbank \
-    --use-flash-attn True \
-    --method pure_text_sampling_original_0.5B \
-    --use-lora False # --on-the-fly-feats True
 
-fi
 
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "stage 4: "
 ngpu=8
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "stage 3: "
 torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
   --max-duration 80 \
   --enable-musan False \
@@ -97,21 +85,23 @@ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
 fi
 
 
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "stage 4: "
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "stage 5: "
   ngpu=2
+  exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
     --max-duration 40 \
     --enable-musan False \
-    --exp-dir ./slam_omni/exp_speech2text \
+    --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --manifest-dir data/fbank \
     --deepspeed \
    --deepspeed_config ./slam_omni/ds_config_zero1.json \
-    --use-flash-attn False \
-    --use-lora True --unfreeze-llm False --enable-speech-output True
+    --use-flash-attn True \
+    --pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
+    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
    # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-   # --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
+   # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
 
 fi
@@ -62,9 +62,9 @@ from model import SPEECH_LLM, EncoderProjector
 
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from train import add_model_arguments
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -126,43 +126,6 @@ def average_checkpoints(
     return avg
 
 
-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--llm-path-or-name",
-        type=str,
-        default="",
-        help="Path or name of the large language model.",
-    )
-
-    parser.add_argument(
-        "--speech-encoder-path-or-name",
-        type=str,
-        default="whisper-large-v2",
-        help="Path or name of the speech encoder.",
-    )
-
-    parser.add_argument(
-        "--encoder-projector-ds-rate",
-        type=int,
-        default=8,
-        help="Downsample rate for the encoder projector.",
-    )
-
-    parser.add_argument(
-        "--use-flash-attn",
-        type=str2bool,
-        default=True,
-        help="Whether to use flash attention.",
-    )
-
-    parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="Whether to use lora fine-tuned llm checkpoint.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -351,16 +314,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id, speech_output in zip(cut_ids, generated_speech_output):
-            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-            #torchaudio.save(save_path, speech_output.cpu(), 16000)
-            print(f"speech_output: {speech_output}, cut_id: {cut_id}")
+        with open("test.txt", 'w') as f:
+            for cut_id in cut_ids:
+                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+                #torchaudio.save(save_path, speech_output.cpu(), 16000)
+                print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+                save_str = " ".join([str(i) for i in generated_speech_output])
+                f.write(f"{cut_id}|{save_str}\n")
 
     else:
         generated_ids = model.decode(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
-        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+        print(f"hyps: {hyps}")
+        exit(0)
     return {"beam-search": hyps}
 
 
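Note: the decode hunk above dumps the generated codec token IDs to a flat text file, one "cut_id|token token ..." line per cut (the hard-coded name test.txt comes from the diff). A minimal sketch of reading that file back; the helper name is illustrative and not part of the commit:

def load_speech_tokens(path="test.txt"):
    # Parse the "cut_id|tok tok ..." lines written by the loop above.
    tokens_by_cut = {}
    with open(path) as f:
        for line in f:
            cut_id, token_str = line.rstrip("\n").split("|", maxsplit=1)
            tokens_by_cut[cut_id] = [int(t) for t in token_str.split()]
    return tokens_by_cut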
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother
 from typing import List, Tuple # Added for type hints
+from torchmetrics.classification import MulticlassAccuracy
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
 
@@ -74,9 +74,21 @@ class SPEECH_LLM(nn.Module):
         self.codec_lm_head = nn.Linear(
             self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
         )
+        # to torch.float16
+        self.speech_token_projector = self.speech_token_projector.to(
+            dtype=torch.float16
+        )
+        self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
         self.loss_fct = torch.nn.CrossEntropyLoss()
         self.codec_lm_padding_side = codec_lm_padding_side
 
+        self.audio_accuracy_metric = MulticlassAccuracy(
+            self.codec_lm.vocab_size,
+            top_k=10,
+            average="micro",
+            multidim_average="global",
+            ignore_index=IGNORE_TOKEN_ID,
+        )
     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
     ):
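Note: the audio_accuracy_metric added above uses torchmetrics' MulticlassAccuracy with top_k=10 and ignore_index=IGNORE_TOKEN_ID, i.e. a frame counts as correct when the reference codec token is among the 10 highest-scoring classes, and padded positions are skipped. A standalone sketch with an illustrative vocabulary size and made-up shapes:

import torch
from torchmetrics.classification import MulticlassAccuracy

IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index
vocab_size = 8192       # illustrative codec vocabulary size

metric = MulticlassAccuracy(
    vocab_size,
    top_k=10,
    average="micro",
    multidim_average="global",
    ignore_index=IGNORE_TOKEN_ID,
)
logits = torch.randn(4, vocab_size, 25)   # [batch, num_classes, frames]
labels = torch.randint(0, vocab_size, (4, 25))
labels[:, -5:] = IGNORE_TOKEN_ID          # padded frames are ignored
print(metric(logits, labels).item())      # top-10 accuracy over unpadded frames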
@@ -332,7 +344,7 @@ class SPEECH_LLM(nn.Module):
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
 
-        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
+        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id) # TODO: do we need to change bos tokens to pad token or mask token?
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
 
         # input_ids: seq_len T1, audio_codec seq_len T2
@@ -355,6 +367,11 @@ class SPEECH_LLM(nn.Module):
                 start_idx_content = T_audio - item_len # Start index of the content for item i
                 end_idx_target = start_idx_content + T_merged # End index of the target slice within the content
                 # Add the text_input_embeds to the calculated slice
-                audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+                if end_idx_target > T_audio:
+                    # If the text input is longer than the audio input, we need to pad the audio input
+                    cut_off_len = T_audio - start_idx_content
+                    audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
+                else:
+                    audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
@@ -389,9 +406,12 @@ class SPEECH_LLM(nn.Module):
             audio_labels.detach(),
             ignore_label=IGNORE_TOKEN_ID,
         )
+        audio_topk_acc = self.audio_accuracy_metric(
+            audio_logits.detach(),
+            audio_labels.detach()).item()
 
 
-        return text_loss, acc, codec_loss, audio_acc
+        return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
 
     def decode(
         self,
@@ -473,6 +493,7 @@ class SPEECH_LLM(nn.Module):
              - generated_speech_tokens: List of lists, where each inner list contains
                the generated speech codec tokens for a batch item.
         """
+        assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
        if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
            raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
 
@@ -518,93 +539,88 @@ class SPEECH_LLM(nn.Module):
             attention_mask=merged_prompt_attention_mask,
             max_new_tokens=max_text_new_tokens,
             return_dict_in_generate=True,
+            output_hidden_states=True,
             **final_llm_kwargs
         )
-        for key in text_outputs:
-            print(key, text_outputs[key].shape)
-        # Assume text_outputs is the tensor of generated IDs [B, S_full]
-        generated_text_ids = text_outputs
-        exit(0)
 
-        # --- 3. Get LLM Hidden States for the *Full* Generated Text Sequence ---
-        # Run a separate forward pass to reliably get hidden states for the complete sequence.
-        # This is simpler than parsing the complex output of generate with output_hidden_states=True.
-        full_text_embeds = self.llm.get_input_embeddings()(generated_text_ids) # [B, S_full, D_llm]
-        full_text_attention_mask = (generated_text_ids != self.llm.config.pad_token_id).long() # [B, S_full]
+        generated_text_ids = text_outputs.sequences # [B, S_full]
+        thinker_token_embeds = [
+            token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_hidden_states = [
+            token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+        thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
 
-        # --- 4. Project Hidden States ---
-        projected_text_embeds = self.speech_token_projector(full_text_embeds) # Shape [B, S_full, D_codec]
+        thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec]
+        thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec]
 
-        # --- 5. Generate Speech Tokens (Autoregressive Loop with Text Context) ---
-        self.codec_lm.to(device)
-        self.codec_lm_head.to(device)
 
-        # Initial input for the codec LM is the BOS token
-        current_speech_input_ids = torch.full(
-            (batch_size, 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=device
+        delay_step = 2
+        thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
+        talker_input_ids = torch.full(
+            (batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
         )
+        talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
+        thinker_input_embeds = torch.cat(
+            [
+                thinker_prompt_part,
+                thinker_reply_part[:, :delay_step, :],
+            ],
+            dim=1,
+        )
+        talker_inputs_embeds += thinker_input_embeds
+        thinker_reply_part = thinker_reply_part[:, delay_step:, :] # [B, S_full, D_codec]
 
 
         past_key_values = None
-        generated_speech_tokens_list = [[] for _ in range(batch_size)]
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
-        text_context_len = projected_text_embeds.shape[1] # S_full
+        # generated_speech_tokens_list = [[] for _ in range(batch_size)]
+        # unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+        generated_speech_tokens_list = []
+        next_token_ids = None
+        # text_context_len = projected_text_embeds.shape[1] # S_full
         for t in range(max_speech_new_tokens):
             # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
-            current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
-            # Add the projected text embedding corresponding to the current timestep `t`
-            if t < text_context_len:
-                # Text context from the full generated text sequence
-                current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
-                inputs_embeds = current_speech_embeds + current_text_context_embed
-            else:
-                # No more text context to add
-                inputs_embeds = current_speech_embeds
-            # Ensure inputs_embeds has the correct dtype for the codec_lm
-            inputs_embeds = inputs_embeds.to(next(self.codec_lm.parameters()).dtype)
+            # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
+            if next_token_ids is not None:
+                talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
+                if thinker_reply_part.shape[1] > 0:
+                    talker_inputs_embeds += thinker_reply_part[:, :1, :]
+                    thinker_reply_part = thinker_reply_part[:, 1:, :] # Remove the first token for next step
+            # # Add the projected text embedding corresponding to the current timestep `t`
+            # if t < text_context_len:
+            #     # Text context from the full generated text sequence
+            #     current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
+            #     inputs_embeds = current_speech_embeds + current_text_context_embed
+            # else:
+            #     # No more text context to add
+            #     inputs_embeds = current_speech_embeds
 
             # Forward pass through codec LM for one step
             # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
             codec_outputs = self.codec_lm(
-                inputs_embeds=inputs_embeds, # Combined embedding for this step
+                inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
                 past_key_values=past_key_values,
                 use_cache=True,
                 return_dict=True,
+                output_hidden_states=True,
                 # No attention mask needed here when using past_key_values and single token input
             )
+            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec]
             # Get logits for the *last* token generated in this step
-            next_token_logits = self.codec_lm_head(codec_outputs.last_hidden_state[:, -1:, :]) # Use -1 index
-            # --- Process Output & Update State ---
-            # Greedy decoding (can be replaced with sampling based on codec_lm_kwargs)
-            # TODO: Implement sampling/beam search for codec LM if needed
-            next_token_ids = torch.argmax(next_token_logits, dim=-1) # Greedy [B, 1]
-            # Mask out finished sequences
-            next_token_ids = next_token_ids * unfinished_sequences.unsqueeze(-1) + \
-                self.codec_lm.config.pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
-
-            # Store generated tokens for unfinished sequences
-            for i in range(batch_size):
-                if unfinished_sequences[i]:
-                    token_id = next_token_ids[i].item()
-                    if token_id == self.codec_lm.config.eos_token_id:
-                        unfinished_sequences[i] = 0 # Mark as finished
-                    elif token_id != self.codec_lm.config.pad_token_id:
-                        generated_speech_tokens_list[i].append(token_id)
-
-            # Prepare for next iteration
-            current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
-            past_key_values = codec_outputs.past_key_values # Update KV cache
-
-            # Stop if all sequences are finished
-            if unfinished_sequences.max() == 0:
+            next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
+            # suppress tokens between 4096:len(vocab)-3
+            next_token_logits[:, 4096:-3] = -float("Inf")
+            next_token_ids = topk_sampling(
+                next_token_logits,
+            )
+            print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
+            if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
                 break
+            # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
+            past_key_values = codec_outputs.past_key_values # Update KV cache
+            generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])
         # --- 6. Return Results ---
         return generated_text_ids, generated_speech_tokens_list
 
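Note: the rewritten generation loop above follows a thinker/talker pattern: the projected hidden states of the text prompt plus the first delay_step reply positions seed the codec LM, and each later step adds one more projected reply embedding to the embedding of the previously sampled codec token. A toy sketch of that alignment, with made-up shapes and names:

import torch

B, D = 1, 1024                     # illustrative batch size and codec hidden size
delay_step = 2
prompt = torch.randn(B, 7, D)      # projected prompt hidden states
reply = torch.randn(B, 12, D)      # projected reply hidden states

# Seed the talker with the prompt plus the first delay_step reply frames.
seed_embeds = torch.cat([prompt, reply[:, :delay_step, :]], dim=1)  # [B, 9, D]
reply = reply[:, delay_step:, :]

# Afterwards, one reply frame is consumed per generated codec token.
step_embed = torch.randn(B, 1, D)  # embedding of the last sampled codec token
if reply.shape[1] > 0:
    step_embed = step_embed + reply[:, :1, :]
    reply = reply[:, 1:, :]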
@@ -626,3 +642,64 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
     )
     denominator = torch.sum(mask)
     return numerator.float() / denominator.float()
+
+
+def topk_sampling(
+    logits,
+    top_k=50,
+    top_p=0.95,
+    temperature=0.8,
+):
+    if temperature != 1.0:
+        logits = logits / temperature
+    # Top-p/top-k filtering
+    logits_filtered = top_k_top_p_filtering(
+        logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+    )
+    # Sample
+    probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
+    tokens = torch.multinomial(probs, num_samples=1)
+
+    return tokens
+
+
+# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
+def top_k_top_p_filtering(
+    logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+    Make sure we keep at least min_tokens_to_keep per batch example in the output
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(
+            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+        )
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+    return logits
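Note: topk_sampling above is plain temperature plus top-k/top-p (nucleus) sampling over the codec vocabulary. A minimal usage sketch, assuming the two functions from the hunk above are in scope; the vocabulary size and suppression range mirror the decode loop but are otherwise illustrative:

import torch

logits = torch.randn(1, 8192)           # [batch, codec vocab]
logits[:, 4096:-3] = -float("Inf")      # same suppression as in the decode loop
token = topk_sampling(logits, top_k=50, top_p=0.95, temperature=0.8)
print(token.shape)                      # torch.Size([1, 1])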
@@ -417,7 +417,7 @@ def compute_loss(
                 labels=target_ids.to(device),
             )
         else:
-            text_loss, acc, codec_loss, codec_acc = model.forward_with_speech_output(
+            text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
                 fbank=feature,
                 input_ids=input_ids.to(device),
                 attention_mask=attention_mask.to(device),
@@ -442,6 +442,9 @@ def compute_loss(
         info["codec_acc"] = (
             codec_acc * info["frames"]
         )
+        info["codec_topk_acc"] = (
+            codec_topk_acc * info["frames"]
+        )
         info["codec_loss"] = codec_loss.detach().cpu().item()
         info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -743,6 +746,7 @@ def run(rank, world_size, args):
         # torch_dtype=torch_dtype,
         # )
         codec_vocab_size = 8192
+        # TODO: modify above vocab size or supress_tokens when decoding
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
             hidden_size=1024,