fix bugs when padding right

root 2025-04-23 07:33:27 +00:00
parent 23fdef2fd3
commit 478d56efd8
4 changed files with 187 additions and 148 deletions

View File

@@ -35,8 +35,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi
 
-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 3: Combine features"
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Combine features"
   manifest_dir=data/fbank
   if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
     pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
@@ -48,39 +48,27 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   fi
 fi
 
-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "stage 2: "
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "stage 3: "
   python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
+    --max-duration 1 \
+    --exp-dir slam_omni/exp_speech2speech_test_flash_attn \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
+    --epoch 997 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method pure_text_sampling \
+    --method small_test_speech2speech \
+    --enable-speech-output True \
     --use-lora True # --on-the-fly-feats True
 fi
-
-if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
-  log "stage 2: "
-  python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
-    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
-    --manifest-dir data/fbank \
-    --use-flash-attn True \
-    --method pure_text_sampling_original_0.5B \
-    --use-lora False # --on-the-fly-feats True
-fi
 
-ngpu=8
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "stage 3: "
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "stage 4: "
+  ngpu=8
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
     --max-duration 80 \
     --enable-musan False \
@@ -97,21 +85,23 @@ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
 fi
 
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "stage 4: "
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "stage 5: "
   ngpu=2
+  exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
     --max-duration 40 \
     --enable-musan False \
-    --exp-dir ./slam_omni/exp_speech2text \
+    --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./slam_omni/ds_config_zero1.json \
-    --use-flash-attn False \
-    --use-lora True --unfreeze-llm False --enable-speech-output True
+    --use-flash-attn True \
+    --pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
+    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
     # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-    # --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
+    # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
 fi

View File

@@ -62,9 +62,9 @@ from model import SPEECH_LLM, EncoderProjector
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from train import add_model_arguments
 
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -126,43 +126,6 @@ def average_checkpoints(
     return avg
 
-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--llm-path-or-name",
-        type=str,
-        default="",
-        help="Path or name of the large language model.",
-    )
-
-    parser.add_argument(
-        "--speech-encoder-path-or-name",
-        type=str,
-        default="whisper-large-v2",
-        help="Path or name of the speech encoder.",
-    )
-
-    parser.add_argument(
-        "--encoder-projector-ds-rate",
-        type=int,
-        default=8,
-        help="Downsample rate for the encoder projector.",
-    )
-
-    parser.add_argument(
-        "--use-flash-attn",
-        type=str2bool,
-        default=True,
-        help="Whether to use flash attention.",
-    )
-
-    parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="Whether to use lora fine-tuned llm checkpoint.",
-    )
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -351,16 +314,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id, speech_output in zip(cut_ids, generated_speech_output):
-            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-            #torchaudio.save(save_path, speech_output.cpu(), 16000)
-            print(f"speech_output: {speech_output}, cut_id: {cut_id}")
+        with open("test.txt", 'w') as f:
+            for cut_id in cut_ids:
+                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+                #torchaudio.save(save_path, speech_output.cpu(), 16000)
+                print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+                save_str = " ".join([str(i) for i in generated_speech_output])
+                f.write(f"{cut_id}|{save_str}\n")
     else:
         generated_ids = model.decode(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
        )
-    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+    print(f"hyps: {hyps}")
+    exit(0)
 
     return {"beam-search": hyps}

View File

@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother
 from typing import List, Tuple  # Added for type hints
+from torchmetrics.classification import MulticlassAccuracy
 
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
@@ -74,9 +74,21 @@ class SPEECH_LLM(nn.Module):
             self.codec_lm_head = nn.Linear(
                 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
             )
+            # to torch.float16
+            self.speech_token_projector = self.speech_token_projector.to(
+                dtype=torch.float16
+            )
+            self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
         self.loss_fct = torch.nn.CrossEntropyLoss()
         self.codec_lm_padding_side = codec_lm_padding_side
+        self.audio_accuracy_metric = MulticlassAccuracy(
+            self.codec_lm.vocab_size,
+            top_k=10,
+            average="micro",
+            multidim_average="global",
+            ignore_index=IGNORE_TOKEN_ID,
+        )
 
     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
     ):
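Note: the new audio_accuracy_metric uses torchmetrics' MulticlassAccuracy with top_k=10, so a codec prediction counts as correct when the label appears among the 10 highest logits, while ignore_index skips padded positions. A standalone toy illustration (the shapes and values below are made up, not the model's):

import torch
from torchmetrics.classification import MulticlassAccuracy

IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index
metric = MulticlassAccuracy(
    8192,                        # toy codec vocab size
    top_k=10,
    average="micro",
    multidim_average="global",
    ignore_index=IGNORE_TOKEN_ID,
)
logits = torch.randn(6, 8192)    # (num_tokens, vocab_size)
labels = torch.tensor([3, 11, IGNORE_TOKEN_ID, 42, 7, 8191])
top10_acc = metric(logits, labels).item()  # fraction of non-ignored labels inside the top-10 logits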
@@ -332,7 +344,7 @@ class SPEECH_LLM(nn.Module):
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
 
-        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
+        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)  # TODO: do we need to change bos tokens to pad token or mask token?
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
 
         # input_ids: seq_len T1, audio_codec seq_len T2
@@ -355,6 +367,11 @@ class SPEECH_LLM(nn.Module):
                 start_idx_content = T_audio - item_len  # Start index of the content for item i
                 end_idx_target = start_idx_content + T_merged  # End index of the target slice within the content
                 # Add the text_input_embeds to the calculated slice
-                audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+                if end_idx_target > T_audio:
+                    # If the text input is longer than the audio input, we need to pad the audio input
+                    cut_off_len = T_audio - start_idx_content
+                    audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
+                else:
+                    audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
@@ -389,9 +406,12 @@ class SPEECH_LLM(nn.Module):
             audio_labels.detach(),
             ignore_label=IGNORE_TOKEN_ID,
         )
+        audio_topk_acc = self.audio_accuracy_metric(
+            audio_logits.detach(),
+            audio_labels.detach()).item()
 
-        return text_loss, acc, codec_loss, audio_acc
+        return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
 
     def decode(
         self,
@@ -473,6 +493,7 @@ class SPEECH_LLM(nn.Module):
             - generated_speech_tokens: List of lists, where each inner list contains
               the generated speech codec tokens for a batch item.
         """
+        assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
         if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
             raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
@@ -518,93 +539,88 @@ class SPEECH_LLM(nn.Module):
             attention_mask=merged_prompt_attention_mask,
             max_new_tokens=max_text_new_tokens,
             return_dict_in_generate=True,
+            output_hidden_states=True,
             **final_llm_kwargs
         )
-        for key in text_outputs:
-            print(key, text_outputs[key].shape)
-        # Assume text_outputs is the tensor of generated IDs [B, S_full]
-        generated_text_ids = text_outputs
-        exit(0)
-        # --- 3. Get LLM Hidden States for the *Full* Generated Text Sequence ---
-        # Run a separate forward pass to reliably get hidden states for the complete sequence.
-        # This is simpler than parsing the complex output of generate with output_hidden_states=True.
-        full_text_embeds = self.llm.get_input_embeddings()(generated_text_ids)  # [B, S_full, D_llm]
-        full_text_attention_mask = (generated_text_ids != self.llm.config.pad_token_id).long()  # [B, S_full]
-
-        # --- 4. Project Hidden States ---
-        projected_text_embeds = self.speech_token_projector(full_text_embeds)  # Shape [B, S_full, D_codec]
-
-        # --- 5. Generate Speech Tokens (Autoregressive Loop with Text Context) ---
-        self.codec_lm.to(device)
-        self.codec_lm_head.to(device)
-        # Initial input for the codec LM is the BOS token
-        current_speech_input_ids = torch.full(
-            (batch_size, 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=device
+        generated_text_ids = text_outputs.sequences  # [B, S_full]
+        thinker_token_embeds = [
+            token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_hidden_states = [
+            token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+        thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
+        thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)  # [B, S_full, D_codec]
+        thinker_reply_part = self.speech_token_projector(thinker_reply_part)  # [B, S_full, D_codec]
+
+        delay_step = 2
+        thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
+        talker_input_ids = torch.full(
+            (batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
         )
+        talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)  # [B, S_full, D_codec]
+        thinker_input_embeds = torch.cat(
+            [
+                thinker_prompt_part,
+                thinker_reply_part[:, :delay_step, :],
+            ],
+            dim=1,
+        )
+        talker_inputs_embeds += thinker_input_embeds
+        thinker_reply_part = thinker_reply_part[:, delay_step:, :]  # [B, S_full, D_codec]
 
         past_key_values = None
-        generated_speech_tokens_list = [[] for _ in range(batch_size)]
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
-
-        text_context_len = projected_text_embeds.shape[1]  # S_full
+        # generated_speech_tokens_list = [[] for _ in range(batch_size)]
+        # unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+        generated_speech_tokens_list = []
+        next_token_ids = None
+        # text_context_len = projected_text_embeds.shape[1] # S_full
 
         for t in range(max_speech_new_tokens):
             # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
-            current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids)  # [B, 1, D_codec]
-
-            # Add the projected text embedding corresponding to the current timestep `t`
-            if t < text_context_len:
-                # Text context from the full generated text sequence
-                current_text_context_embed = projected_text_embeds[:, t:t+1, :]  # [B, 1, D_codec]
-                inputs_embeds = current_speech_embeds + current_text_context_embed
-            else:
-                # No more text context to add
-                inputs_embeds = current_speech_embeds
-
-            # Ensure inputs_embeds has the correct dtype for the codec_lm
-            inputs_embeds = inputs_embeds.to(next(self.codec_lm.parameters()).dtype)
+            # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
+            if next_token_ids is not None:
+                talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids)  # [B, 1, D_codec]
+                if thinker_reply_part.shape[1] > 0:
+                    talker_inputs_embeds += thinker_reply_part[:, :1, :]
+                    thinker_reply_part = thinker_reply_part[:, 1:, :]  # Remove the first token for next step
+            # # Add the projected text embedding corresponding to the current timestep `t`
+            # if t < text_context_len:
+            #     # Text context from the full generated text sequence
+            #     current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
+            #     inputs_embeds = current_speech_embeds + current_text_context_embed
+            # else:
+            #     # No more text context to add
+            #     inputs_embeds = current_speech_embeds
 
             # Forward pass through codec LM for one step
             # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
             codec_outputs = self.codec_lm(
-                inputs_embeds=inputs_embeds,  # Combined embedding for this step
+                inputs_embeds=talker_inputs_embeds,  # Combined embedding for this step
                 past_key_values=past_key_values,
                 use_cache=True,
                 return_dict=True,
+                output_hidden_states=True,
                 # No attention mask needed here when using past_key_values and single token input
             )
+            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]  # [B, D_codec]
             # Get logits for the *last* token generated in this step
-            next_token_logits = self.codec_lm_head(codec_outputs.last_hidden_state[:, -1:, :])  # Use -1 index
-
-            # --- Process Output & Update State ---
-            # Greedy decoding (can be replaced with sampling based on codec_lm_kwargs)
-            # TODO: Implement sampling/beam search for codec LM if needed
-            next_token_ids = torch.argmax(next_token_logits, dim=-1)  # Greedy [B, 1]
-
-            # Mask out finished sequences
-            next_token_ids = next_token_ids * unfinished_sequences.unsqueeze(-1) + \
-                self.codec_lm.config.pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
-
-            # Store generated tokens for unfinished sequences
-            for i in range(batch_size):
-                if unfinished_sequences[i]:
-                    token_id = next_token_ids[i].item()
-                    if token_id == self.codec_lm.config.eos_token_id:
-                        unfinished_sequences[i] = 0  # Mark as finished
-                    elif token_id != self.codec_lm.config.pad_token_id:
-                        generated_speech_tokens_list[i].append(token_id)
-
-            # Prepare for next iteration
-            current_speech_input_ids = next_token_ids  # Use the newly generated token ID as input for next step
-            past_key_values = codec_outputs.past_key_values  # Update KV cache
-
-            # Stop if all sequences are finished
-            if unfinished_sequences.max() == 0:
+            next_token_logits = self.codec_lm_head(last_token_hidden_state)  # Use -1 index
+            # suppress tokens between 4096:len(vocab)-3
+            next_token_logits[:, 4096:-3] = -float("Inf")
+            next_token_ids = topk_sampling(
+                next_token_logits,
+            )
+            print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
+            if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
                 break
+            # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
+            past_key_values = codec_outputs.past_key_values  # Update KV cache
+            generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])
 
         # --- 6. Return Results ---
         return generated_text_ids, generated_speech_tokens_list
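Note: the rewritten loop follows a thinker/talker pattern: the projected prompt hidden states plus the first delay_step reply frames seed the talker, and the remaining reply frames are consumed one per generation step. A toy sketch of that alignment (all sizes below are invented):

import torch

B, S_prompt, S_reply, D, delay_step = 1, 5, 8, 4, 2
thinker_prompt_part = torch.randn(B, S_prompt, D)       # projected prompt hidden states
thinker_reply_part = torch.randn(B, S_reply, D)         # projected reply hidden states
bos_embeds = torch.zeros(B, S_prompt + delay_step, D)   # stands in for the codec BOS embeddings

talker_inputs_embeds = bos_embeds + torch.cat(
    [thinker_prompt_part, thinker_reply_part[:, :delay_step, :]], dim=1
)
thinker_reply_part = thinker_reply_part[:, delay_step:, :]   # fed one frame per later step
print(talker_inputs_embeds.shape, thinker_reply_part.shape)  # (1, 7, 4) and (1, 6, 4)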
@@ -626,3 +642,64 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
     )
     denominator = torch.sum(mask)
     return numerator.float() / denominator.float()
+
+
+def topk_sampling(
+    logits,
+    top_k=50,
+    top_p=0.95,
+    temperature=0.8,
+):
+    if temperature != 1.0:
+        logits = logits / temperature
+    # Top-p/top-k filtering
+    logits_filtered = top_k_top_p_filtering(
+        logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+    )
+    # Sample
+    probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
+    tokens = torch.multinomial(probs, num_samples=1)
+    return tokens
+
+
+# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
+def top_k_top_p_filtering(
+    logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(
+            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+        )
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+    return logits
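Note: a quick usage sketch for the sampling helpers added above, mirroring the suppression range applied before sampling in the new decoding loop (the vocabulary size is an assumption; topk_sampling is the function from this hunk):

import torch

vocab_size = 8192                       # assumed codec vocab size
logits = torch.randn(1, vocab_size)
logits[:, 4096:-3] = -float("Inf")      # suppress the same range as the decoding loop
token = topk_sampling(logits, top_k=50, top_p=0.95, temperature=0.8)
print(token.shape)                      # (1, 1): one sampled codec token id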

View File

@@ -417,7 +417,7 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     else:
-        text_loss, acc, codec_loss, codec_acc = model.forward_with_speech_output(
+        text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
             fbank=feature,
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
@@ -442,6 +442,9 @@ def compute_loss(
         info["codec_acc"] = (
             codec_acc * info["frames"]
         )
+        info["codec_topk_acc"] = (
+            codec_topk_acc * info["frames"]
+        )
         info["codec_loss"] = codec_loss.detach().cpu().item()
         info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -743,6 +746,7 @@ def run(rank, world_size, args):
         # torch_dtype=torch_dtype,
         # )
         codec_vocab_size = 8192
+        # TODO: modify above vocab size or supress_tokens when decoding
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
             hidden_size=1024,
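Note: the codec LM in run() appears to be configured from a small Qwen2Config rather than a pretrained checkpoint; only vocab_size=8192 and hidden_size=1024 are visible in this hunk, so the remaining sizes in the sketch below are placeholders, not the repository's values:

from transformers import Qwen2Config, Qwen2ForCausalLM

codec_vocab_size = 8192
config = Qwen2Config(
    vocab_size=codec_vocab_size,
    hidden_size=1024,
    num_hidden_layers=12,        # placeholder
    num_attention_heads=16,      # placeholder
    num_key_value_heads=16,      # placeholder
    intermediate_size=2048,      # placeholder
)
codec_lm = Qwen2ForCausalLM(config)   # randomly initialized talker/codec LM
print(f"{sum(p.numel() for p in codec_lm.parameters()) / 1e6:.1f}M parameters")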