add codec decode

Yuekai Zhang 2025-04-21 17:57:57 +08:00
parent 09d81b44a7
commit 23fdef2fd3
3 changed files with 250 additions and 18 deletions

View File

@@ -346,10 +346,19 @@ def decode_one_batch(
messages.append(message)
input_ids, attention_mask = preprocess(messages, tokenizer)
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
if params.enable_speech_output:
generated_ids, generated_speech_output = model.decode_with_speech_output(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
for cut_id, speech_output in zip(cut_ids, generated_speech_output):
# save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
# torchaudio.save(save_path, speech_output.cpu(), 16000)
print(f"speech_output: {speech_output}, cut_id: {cut_id}")
else:
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return {"beam-search": hyps}
@@ -586,10 +595,71 @@ def main():
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
else:
attn_implementation = "eager"
torch_dtype = torch.float16
# codec_lm = AutoModelForCausalLM.from_pretrained(
# params.llm_path_or_name,
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
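# Instead of loading a pretrained LM, the codec LM is built from scratch with the small Qwen2 config below.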
codec_vocab_size = 8192
config = Qwen2Config(
vocab_size=codec_vocab_size,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=16,
num_key_value_heads=16,
intermediate_size=2048,
max_position_embeddings=4096,
)
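# A small decoder-only Transformer (12 layers, 1024 hidden size, 16 heads) over the 8192-entry codec vocabulary.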
# codec_lm = Qwen2ForCausalLM(config=config)
# Pass attn_implementation and torch_dtype to the constructor
# Use AutoModelForCausalLM.from_config for more generality
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype
)
# cosyvoice2_token_size = 6561
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
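# The last three entries of the codec vocabulary are reserved as pad / eos / bos for the codec LM.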
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
codec_lm = get_peft_model(codec_lm, lora_config)
codec_lm.print_trainable_parameters()
else:
codec_lm = None
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
if params.avg > 1:

View File

@@ -1,6 +1,7 @@
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
from typing import List, Tuple # Added for type hints
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@@ -444,6 +445,168 @@ class SPEECH_LLM(nn.Module):
# )
return generated_ids
def decode_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024,
max_speech_new_tokens: int = 1024, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]:
"""
Generates text and corresponding speech tokens using the revised logic.
Args:
fbank: Input audio features.
input_ids: Input token IDs for the text prompt.
attention_mask: Attention mask for the text prompt.
max_text_new_tokens: Max new tokens for text generation.
max_speech_new_tokens: Max new tokens for speech generation.
llm_kwargs: Additional arguments for self.llm.generate.
codec_lm_kwargs: Additional arguments for self.codec_lm.generate.
Returns:
Tuple[torch.LongTensor, List[List[int]]]:
- generated_text_ids: Tensor of generated text token IDs (including prompt).
- generated_speech_tokens: List of lists, where each inner list contains
the generated speech codec tokens for a batch item.
"""
if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
device = next(self.parameters()).device # Use model's device
batch_size = fbank.shape[0]
# --- 1. Prepare Prompt Embeddings ---
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
# Merge speech features with prompt embeddings
(
merged_prompt_inputs_embeds,
merged_prompt_attention_mask,
_,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, prompt_embeds, input_ids, attention_mask
)
# --- 2. Generate Text using LLM ---
# Use merged embeds/mask as input to generate
# Ensure kwargs passed are suitable for llm.generate
# Note: Using default generation params from `decode` if not provided in kwargs
final_llm_kwargs = {
"bos_token_id": self.llm.config.bos_token_id,
"eos_token_id": self.llm.config.eos_token_id,
"pad_token_id": self.llm.config.pad_token_id,
"num_beams": 1,
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"top_p": 0.5,
"top_k": 20,
"repetition_penalty": 1.1,
"temperature": 0.7,
**(llm_kwargs or {}) # User-provided kwargs override defaults
}
text_outputs = self.llm.generate(
inputs_embeds=merged_prompt_inputs_embeds,
attention_mask=merged_prompt_attention_mask,
max_new_tokens=max_text_new_tokens,
return_dict_in_generate=True,
**final_llm_kwargs
)
# With return_dict_in_generate=True, generate() returns a ModelOutput whose
# `sequences` field holds the generated token IDs, shape [B, S_full].
generated_text_ids = text_outputs.sequences
# --- 3. Get LLM Hidden States for the *Full* Generated Text Sequence ---
# Run a separate forward pass to reliably get hidden states for the complete sequence.
# This is simpler than parsing the complex output of generate with output_hidden_states=True.
full_text_embeds = self.llm.get_input_embeddings()(generated_text_ids) # [B, S_full, D_llm]
full_text_attention_mask = (generated_text_ids != self.llm.config.pad_token_id).long() # [B, S_full]
# --- 4. Project Hidden States ---
projected_text_embeds = self.speech_token_projector(full_text_embeds) # Shape [B, S_full, D_codec]
# --- 5. Generate Speech Tokens (Autoregressive Loop with Text Context) ---
self.codec_lm.to(device)
self.codec_lm_head.to(device)
# Initial input for the codec LM is the BOS token
current_speech_input_ids = torch.full(
(batch_size, 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=device
)
past_key_values = None
generated_speech_tokens_list = [[] for _ in range(batch_size)]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
text_context_len = projected_text_embeds.shape[1] # S_full
for t in range(max_speech_new_tokens):
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
# Add the projected text embedding corresponding to the current timestep `t`
if t < text_context_len:
# Text context from the full generated text sequence
current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
inputs_embeds = current_speech_embeds + current_text_context_embed
else:
# No more text context to add
inputs_embeds = current_speech_embeds
# Ensure inputs_embeds has the correct dtype for the codec_lm
inputs_embeds = inputs_embeds.to(next(self.codec_lm.parameters()).dtype)
# Forward pass through codec LM for one step
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
codec_outputs = self.codec_lm(
inputs_embeds=inputs_embeds, # Combined embedding for this step
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
output_hidden_states=True, # CausalLM outputs expose hidden states via `hidden_states`, not `last_hidden_state`
# No attention mask needed here when using past_key_values and single token input
)
# Take the final layer's hidden state for this step and project it to codec logits
next_token_logits = self.codec_lm_head(codec_outputs.hidden_states[-1][:, -1:, :])
# --- Process Output & Update State ---
# Greedy decoding (can be replaced with sampling based on codec_lm_kwargs)
# TODO: Implement sampling/beam search for codec LM if needed
next_token_ids = torch.argmax(next_token_logits, dim=-1) # Greedy [B, 1]
# Mask out finished sequences
next_token_ids = next_token_ids * unfinished_sequences.unsqueeze(-1) + \
self.codec_lm.config.pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
# Store generated tokens for unfinished sequences
for i in range(batch_size):
if unfinished_sequences[i]:
token_id = next_token_ids[i].item()
if token_id == self.codec_lm.config.eos_token_id:
unfinished_sequences[i] = 0 # Mark as finished
elif token_id != self.codec_lm.config.pad_token_id:
generated_speech_tokens_list[i].append(token_id)
# Prepare for next iteration
current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
past_key_values = codec_outputs.past_key_values # Update KV cache
# Stop if all sequences are finished
if unfinished_sequences.max() == 0:
break
# --- 6. Return Results ---
return generated_text_ids, generated_speech_tokens_list
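For reference, a hedged usage sketch of this method as invoked from the decode script in this commit (feature, input_ids, attention_mask, device, and tokenizer are assumed to be prepared as in decode_one_batch):

generated_ids, generated_speech_output = model.decode_with_speech_output(
    feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
for hyp, speech_tokens in zip(hyps, generated_speech_output):
    # speech_tokens is a Python list of codec token IDs for one utterance
    print(hyp, len(speech_tokens))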
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.

View File

@@ -133,20 +133,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to use lora to fine-tune llm.",
)
parser.add_argument(
"--unfreeze-llm",
type=str2bool,
default=False,
help="Whether to unfreeze llm during training.",
)
parser.add_argument(
"--unfreeze-speech-projector",
type=str2bool,
default=False,
help="Whether to unfreeze speech adaptor during training.",
)
parser.add_argument(
"--enable-speech-output",
type=str2bool,
@@ -224,6 +210,19 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--unfreeze-llm",
type=str2bool,
default=False,
help="Whether to unfreeze llm during training.",
)
parser.add_argument(
"--unfreeze-speech-projector",
type=str2bool,
default=False,
help="Whether to unfreeze speech adaptor during training.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)