Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

add codec decode

parent 09d81b44a7 · commit 23fdef2fd3
@@ -346,10 +346,19 @@ def decode_one_batch(
     messages.append(message)

     input_ids, attention_mask = preprocess(messages, tokenizer)
-    generated_ids = model.decode(
-        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
-    )
+    if params.enable_speech_output:
+        generated_ids, generated_speech_output = model.decode_with_speech_output(
+            feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        )
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        for cut_id, speech_output in zip(cut_ids, generated_speech_output):
+            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+            # torchaudio.save(save_path, speech_output.cpu(), 16000)
+            print(f"speech_output: {speech_output}, cut_id: {cut_id}")
+    else:
+        generated_ids = model.decode(
+            feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        )
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

     return {"beam-search": hyps}
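
Note: the speech_output values printed in the new branch above are codec token IDs (one List[int] per cut), not waveforms, which is presumably why the torchaudio.save call is left commented out. A minimal sketch of how such tokens could be written to disk as audio, assuming a separate token-to-waveform decoder (the codec_decoder below is hypothetical and not part of this commit):

    import torch
    import torchaudio

    # speech_output: List[int] of codec token IDs for one cut, as produced by
    # model.decode_with_speech_output in the loop above (save_path as in the
    # commented-out line).
    codes = torch.tensor(speech_output, dtype=torch.long).unsqueeze(0)  # [1, T_codec]
    wav = codec_decoder(codes)  # hypothetical codec/vocoder returning [1, T_samples]
    torchaudio.save(str(save_path), wav.cpu(), 16000)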
@@ -586,10 +595,71 @@ def main():
         speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
     )

+    if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
+        if params.use_flash_attn:
+            attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
+        else:
+            attn_implementation = "eager"
+            torch_dtype = torch.float16
+
+        # codec_lm = AutoModelForCausalLM.from_pretrained(
+        #     params.llm_path_or_name,
+        #     attn_implementation=attn_implementation,
+        #     torch_dtype=torch_dtype,
+        # )
+        codec_vocab_size = 8192
+        config = Qwen2Config(
+            vocab_size=codec_vocab_size,
+            hidden_size=1024,
+            num_hidden_layers=12,
+            num_attention_heads=16,
+            num_key_value_heads=16,
+            intermediate_size=2048,
+            max_position_embeddings=4096,
+        )
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor.
+        # Use AutoModelForCausalLM.from_config for more generality.
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+        )
+        # cosyvoice2_token_size = 6561
+        codec_lm.resize_token_embeddings(codec_vocab_size)
+        codec_lm.vocab_size = codec_vocab_size
+        codec_lm.config.pad_token_id = codec_vocab_size - 1
+        codec_lm.config.eos_token_id = codec_vocab_size - 2
+        codec_lm.config.bos_token_id = codec_vocab_size - 3
+        if params.use_lora:
+            lora_config = LoraConfig(
+                r=64,
+                lora_alpha=16,
+                target_modules=[
+                    "q_proj",
+                    "k_proj",
+                    "v_proj",
+                    "o_proj",
+                    "up_proj",
+                    "gate_proj",
+                    "down_proj",
+                ],
+                lora_dropout=0.05,
+                task_type="CAUSAL_LM",
+            )
+            codec_lm = get_peft_model(codec_lm, lora_config)
+            codec_lm.print_trainable_parameters()
+    else:
+        codec_lm = None

     model = SPEECH_LLM(
         speech_encoder,
         llm,
         encoder_projector,
+        codec_lm,
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
     )

     if params.avg > 1:
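
The codec LM defined above reserves the top three IDs of its 8192-entry vocabulary for control tokens. A small worked sketch of the resulting ID layout (the values follow directly from the assignments in the hunk; the CosyVoice2 figure comes from the cosyvoice2_token_size comment):

    codec_vocab_size = 8192
    pad_token_id = codec_vocab_size - 1  # 8191
    eos_token_id = codec_vocab_size - 2  # 8190
    bos_token_id = codec_vocab_size - 3  # 8189
    # IDs 0 .. 8188 remain free for actual speech codec tokens; a CosyVoice2-style
    # codebook of 6561 entries fits comfortably below the reserved IDs.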
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother
+from typing import List, Tuple  # Added for type hints

 IGNORE_TOKEN_ID = LabelSmoother.ignore_index

@@ -444,6 +445,168 @@ class SPEECH_LLM(nn.Module):
         # )
         return generated_ids

+    def decode_with_speech_output(
+        self,
+        fbank: torch.Tensor = None,
+        input_ids: torch.LongTensor = None,  # Prompt input_ids
+        attention_mask: torch.Tensor = None,  # Prompt attention_mask
+        max_text_new_tokens: int = 1024,
+        max_speech_new_tokens: int = 1024,  # Max length for speech tokens
+        llm_kwargs: dict = None,  # Kwargs for text LLM generate
+        codec_lm_kwargs: dict = None,  # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
+    ) -> Tuple[torch.LongTensor, List[List[int]]]:
+        """
+        Generates text and corresponding speech tokens using the revised logic.
+
+        Args:
+            fbank: Input audio features.
+            input_ids: Input token IDs for the text prompt.
+            attention_mask: Attention mask for the text prompt.
+            max_text_new_tokens: Max new tokens for text generation.
+            max_speech_new_tokens: Max new tokens for speech generation.
+            llm_kwargs: Additional arguments for self.llm.generate.
+            codec_lm_kwargs: Additional arguments for self.codec_lm.generate.
+
+        Returns:
+            Tuple[torch.LongTensor, List[List[int]]]:
+                - generated_text_ids: Tensor of generated text token IDs (including prompt).
+                - generated_speech_tokens: List of lists, where each inner list contains
+                  the generated speech codec tokens for a batch item.
+        """
+        if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
+            raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
+
+        device = next(self.parameters()).device  # Use model's device
+        batch_size = fbank.shape[0]
+
+        # --- 1. Prepare Prompt Embeddings ---
+        encoder_outs = self.encoder(fbank)
+        speech_features = self.encoder_projector(encoder_outs)
+        speech_features = speech_features.to(self.llm.dtype)  # Ensure matching dtype
+
+        prompt_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        # Merge speech features with prompt embeddings
+        (
+            merged_prompt_inputs_embeds,
+            merged_prompt_attention_mask,
+            _,
+            _,
+        ) = self._merge_input_ids_with_speech_features(
+            speech_features, prompt_embeds, input_ids, attention_mask
+        )
+
+        # --- 2. Generate Text using LLM ---
+        # Use merged embeds/mask as input to generate
+        # Ensure kwargs passed are suitable for llm.generate
+        # Note: Using default generation params from `decode` if not provided in kwargs
+        final_llm_kwargs = {
+            "bos_token_id": self.llm.config.bos_token_id,
+            "eos_token_id": self.llm.config.eos_token_id,
+            "pad_token_id": self.llm.config.pad_token_id,
+            "num_beams": 1,
+            "do_sample": True,  # Typically false for S2ST/S2TT tasks unless exploration needed
+            "top_p": 0.5,
+            "top_k": 20,
+            "repetition_penalty": 1.1,
+            "temperature": 0.7,
+            **(llm_kwargs or {}),  # User-provided kwargs override defaults
+        }
+
+        text_outputs = self.llm.generate(
+            inputs_embeds=merged_prompt_inputs_embeds,
+            attention_mask=merged_prompt_attention_mask,
+            max_new_tokens=max_text_new_tokens,
+            return_dict_in_generate=True,
+            **final_llm_kwargs,
+        )
+        # With return_dict_in_generate=True, generate() returns an output object whose
+        # `sequences` field holds the generated text token IDs [B, S_full]
+        generated_text_ids = text_outputs.sequences
+
+        # --- 3. Get LLM Hidden States for the *Full* Generated Text Sequence ---
+        # Run a separate forward pass to reliably get hidden states for the complete sequence.
+        # This is simpler than parsing the complex output of generate with output_hidden_states=True.
+        full_text_embeds = self.llm.get_input_embeddings()(generated_text_ids)  # [B, S_full, D_llm]
+        full_text_attention_mask = (generated_text_ids != self.llm.config.pad_token_id).long()  # [B, S_full]
+
+        # --- 4. Project Hidden States ---
+        projected_text_embeds = self.speech_token_projector(full_text_embeds)  # Shape [B, S_full, D_codec]
+
+        # --- 5. Generate Speech Tokens (Autoregressive Loop with Text Context) ---
+        self.codec_lm.to(device)
+        self.codec_lm_head.to(device)
+
+        # Initial input for the codec LM is the BOS token
+        current_speech_input_ids = torch.full(
+            (batch_size, 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=device
+        )
+
+        past_key_values = None
+        generated_speech_tokens_list = [[] for _ in range(batch_size)]
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+
+        text_context_len = projected_text_embeds.shape[1]  # S_full
+
+        for t in range(max_speech_new_tokens):
+            # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
+            current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids)  # [B, 1, D_codec]
+
+            # Add the projected text embedding corresponding to the current timestep `t`
+            if t < text_context_len:
+                # Text context from the full generated text sequence
+                current_text_context_embed = projected_text_embeds[:, t : t + 1, :]  # [B, 1, D_codec]
+                inputs_embeds = current_speech_embeds + current_text_context_embed
+            else:
+                # No more text context to add
+                inputs_embeds = current_speech_embeds
+
+            # Ensure inputs_embeds has the correct dtype for the codec_lm
+            inputs_embeds = inputs_embeds.to(next(self.codec_lm.parameters()).dtype)
+
+            # Forward pass through codec LM for one step
+            # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
+            codec_outputs = self.codec_lm(
+                inputs_embeds=inputs_embeds,  # Combined embedding for this step
+                past_key_values=past_key_values,
+                use_cache=True,
+                return_dict=True,
+                # No attention mask needed here when using past_key_values and single token input
+            )
+
+            # Get logits for the *last* token generated in this step
+            next_token_logits = self.codec_lm_head(codec_outputs.last_hidden_state[:, -1:, :])  # Use -1 index
+
+            # --- Process Output & Update State ---
+            # Greedy decoding (can be replaced with sampling based on codec_lm_kwargs)
+            # TODO: Implement sampling/beam search for codec LM if needed
+            next_token_ids = torch.argmax(next_token_logits, dim=-1)  # Greedy [B, 1]
+
+            # Mask out finished sequences
+            next_token_ids = next_token_ids * unfinished_sequences.unsqueeze(-1) + \
+                self.codec_lm.config.pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
+
+            # Store generated tokens for unfinished sequences
+            for i in range(batch_size):
+                if unfinished_sequences[i]:
+                    token_id = next_token_ids[i].item()
+                    if token_id == self.codec_lm.config.eos_token_id:
+                        unfinished_sequences[i] = 0  # Mark as finished
+                    elif token_id != self.codec_lm.config.pad_token_id:
+                        generated_speech_tokens_list[i].append(token_id)
+
+            # Prepare for next iteration
+            current_speech_input_ids = next_token_ids  # Use the newly generated token ID as input for next step
+            past_key_values = codec_outputs.past_key_values  # Update KV cache
+
+            # Stop if all sequences are finished
+            if unfinished_sequences.max() == 0:
+                break
+
+        # --- 6. Return Results ---
+        return generated_text_ids, generated_speech_tokens_list
+

 def compute_accuracy(pad_outputs, pad_targets, ignore_label):
     """Calculate accuracy.
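
A minimal usage sketch of the new method, mirroring the call added to decode_one_batch in the first hunk (feature, input_ids, attention_mask, device, and tokenizer are assumed to be prepared exactly as in that function):

    # generated_ids: text token IDs; generated_speech_output: List[List[int]] of
    # codec token IDs, one list per cut in the batch
    generated_ids, generated_speech_output = model.decode_with_speech_output(
        feature,
        input_ids.to(device, dtype=torch.long),
        attention_mask.to(device),
    )
    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)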
@@ -133,20 +133,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to use lora to fine-tune llm.",
     )

-    parser.add_argument(
-        "--unfreeze-llm",
-        type=str2bool,
-        default=False,
-        help="Whether to unfreeze llm during training.",
-    )
-
-    parser.add_argument(
-        "--unfreeze-speech-projector",
-        type=str2bool,
-        default=False,
-        help="Whether to unfreeze speech adaptor during training.",
-    )
-
     parser.add_argument(
         "--enable-speech-output",
         type=str2bool,
@@ -224,6 +210,19 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--unfreeze-llm",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze llm during training.",
+    )
+
+    parser.add_argument(
+        "--unfreeze-speech-projector",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze speech adaptor during training.",
+    )
+
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
