mirror of https://github.com/k2-fsa/icefall.git

commit 478d56efd8 (parent 23fdef2fd3)
fix bugs when padding right
@@ -35,8 +35,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 3: Combine features"
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Combine features"
   manifest_dir=data/fbank
   if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
     pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
@@ -48,39 +48,27 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   fi
 fi

-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "stage 2: "
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "stage 3: "
   python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
+    --max-duration 1 \
+    --exp-dir slam_omni/exp_speech2speech_test_flash_attn \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
+    --epoch 997 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method pure_text_sampling \
+    --method small_test_speech2speech \
+    --enable-speech-output True \
     --use-lora True # --on-the-fly-feats True
 fi

-if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
-  log "stage 2: "
-  python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
-    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
-    --manifest-dir data/fbank \
-    --use-flash-attn True \
-    --method pure_text_sampling_original_0.5B \
-    --use-lora False # --on-the-fly-feats True
-fi
-
-ngpu=8
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "stage 3: "
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "stage 4: "
+  ngpu=8
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
     --max-duration 80 \
     --enable-musan False \
@@ -97,21 +85,23 @@ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
 fi


-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "stage 4: "
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "stage 5: "
   ngpu=2
+  exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
   torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
     --max-duration 40 \
     --enable-musan False \
-    --exp-dir ./slam_omni/exp_speech2text \
+    --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./slam_omni/ds_config_zero1.json \
-    --use-flash-attn False \
-    --use-lora True --unfreeze-llm False --enable-speech-output True
+    --use-flash-attn True \
+    --pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
+    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
-    # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-    # --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
+    # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \

 fi
@@ -62,9 +62,9 @@ from model import SPEECH_LLM, EncoderProjector

 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

+from train import add_model_arguments
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -126,43 +126,6 @@ def average_checkpoints(
     return avg


-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--llm-path-or-name",
-        type=str,
-        default="",
-        help="Path or name of the large language model.",
-    )
-
-    parser.add_argument(
-        "--speech-encoder-path-or-name",
-        type=str,
-        default="whisper-large-v2",
-        help="Path or name of the speech encoder.",
-    )
-
-    parser.add_argument(
-        "--encoder-projector-ds-rate",
-        type=int,
-        default=8,
-        help="Downsample rate for the encoder projector.",
-    )
-
-    parser.add_argument(
-        "--use-flash-attn",
-        type=str2bool,
-        default=True,
-        help="Whether to use flash attention.",
-    )
-
-    parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="Whether to use lora fine-tuned llm checkpoint.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -351,16 +314,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id, speech_output in zip(cut_ids, generated_speech_output):
-            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-            # torchaudio.save(save_path, speech_output.cpu(), 16000)
-            print(f"speech_output: {speech_output}, cut_id: {cut_id}")
+        with open("test.txt", 'w') as f:
+            for cut_id in cut_ids:
+                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+                # torchaudio.save(save_path, speech_output.cpu(), 16000)
+                print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+                save_str = " ".join([str(i) for i in generated_speech_output])
+                f.write(f"{cut_id}|{save_str}\n")

     else:
         generated_ids = model.decode(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
-    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
+    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+    print(f"hyps: {hyps}")
+    exit(0)
     return {"beam-search": hyps}
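Note: the new branch above writes one "cut_id|token token ..." line per cut to test.txt. A short sketch of reading such a file back into integer token lists (the format and the hard-coded file name are taken from the code above):

    speech_tokens = {}
    with open("test.txt") as f:
        for line in f:
            cut_id, token_str = line.rstrip("\n").split("|", 1)
            speech_tokens[cut_id] = [int(t) for t in token_str.split()]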
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother
 from typing import List, Tuple  # Added for type hints

+from torchmetrics.classification import MulticlassAccuracy
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index

@@ -74,9 +74,21 @@ class SPEECH_LLM(nn.Module):
         self.codec_lm_head = nn.Linear(
             self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
         )
+        # to torch.float16
+        self.speech_token_projector = self.speech_token_projector.to(
+            dtype=torch.float16
+        )
+        self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
         self.loss_fct = torch.nn.CrossEntropyLoss()
         self.codec_lm_padding_side = codec_lm_padding_side

+        self.audio_accuracy_metric = MulticlassAccuracy(
+            self.codec_lm.vocab_size,
+            top_k=10,
+            average="micro",
+            multidim_average="global",
+            ignore_index=IGNORE_TOKEN_ID,
+        )

     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
     ):
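Note: the audio_accuracy_metric added above reports top-10 accuracy over codec tokens while skipping ignored (padded) positions. A minimal, self-contained sketch of the same torchmetrics call, using made-up shapes rather than the model's real ones:

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    IGNORE_TOKEN_ID = -100   # LabelSmoother.ignore_index
    vocab_size = 4100        # hypothetical codec vocabulary size

    metric = MulticlassAccuracy(
        vocab_size,
        top_k=10,                    # correct if the label is among the 10 highest logits
        average="micro",
        multidim_average="global",
        ignore_index=IGNORE_TOKEN_ID,
    )

    logits = torch.randn(32, vocab_size)          # [num_frames, vocab]
    labels = torch.randint(0, vocab_size, (32,))  # [num_frames]
    labels[-5:] = IGNORE_TOKEN_ID                 # padded frames are excluded from the metric
    print(metric(logits, labels).item())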
@@ -332,7 +344,7 @@ class SPEECH_LLM(nn.Module):
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")

-        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
+        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)  # TODO: do we need to change bos tokens to pad token or mask token?
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

         # input_ids: seq_len T1, audio_codec seq_len T2
@@ -355,7 +367,12 @@ class SPEECH_LLM(nn.Module):
                 start_idx_content = T_audio - item_len  # Start index of the content for item i
                 end_idx_target = start_idx_content + T_merged  # End index of the target slice within the content
                 # Add the text_input_embeds to the calculated slice
-                audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+                if end_idx_target > T_audio:
+                    # If the text input is longer than the audio input, we need to pad the audio input
+                    cut_off_len = T_audio - start_idx_content
+                    audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
+                else:
+                    audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
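Note: the guard added above handles the case where the merged text sequence is longer than the space left in the left-padded audio sequence. A single-item sketch of the index arithmetic (dummy lengths, same variable names as the loop above):

    import torch

    T_audio, D = 10, 4                      # padded audio length, embedding dim
    item_len, T_merged = 6, 8               # real audio length, merged text length
    audio_embeddings = torch.zeros(T_audio, D)
    text_input_embeds = torch.ones(T_merged, D)

    # left padding: the real content sits at the end of the padded sequence
    start_idx_content = T_audio - item_len         # 4
    end_idx_target = start_idx_content + T_merged  # 12 > T_audio

    if end_idx_target > T_audio:
        # text is longer than the remaining audio slots, so only the first
        # cut_off_len text embeddings fit (this is what the fix above does)
        cut_off_len = T_audio - start_idx_content  # 6
        audio_embeddings[start_idx_content:end_idx_target] = text_input_embeds[:cut_off_len]
    else:
        audio_embeddings[start_idx_content:end_idx_target] += text_input_embeds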
@@ -389,9 +406,12 @@ class SPEECH_LLM(nn.Module):
             audio_labels.detach(),
             ignore_label=IGNORE_TOKEN_ID,
         )
+        audio_topk_acc = self.audio_accuracy_metric(
+            audio_logits.detach(),
+            audio_labels.detach()).item()

-        return text_loss, acc, codec_loss, audio_acc
+        return text_loss, acc, codec_loss, audio_acc, audio_topk_acc

     def decode(
         self,
@@ -473,6 +493,7 @@ class SPEECH_LLM(nn.Module):
                 - generated_speech_tokens: List of lists, where each inner list contains
                   the generated speech codec tokens for a batch item.
         """
+        assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
         if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
             raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
@@ -518,93 +539,88 @@ class SPEECH_LLM(nn.Module):
             attention_mask=merged_prompt_attention_mask,
             max_new_tokens=max_text_new_tokens,
             return_dict_in_generate=True,
             output_hidden_states=True,
             **final_llm_kwargs
         )
-        for key in text_outputs:
-            print(key, text_outputs[key].shape)
-        # Assume text_outputs is the tensor of generated IDs [B, S_full]
-        generated_text_ids = text_outputs
-        exit(0)
-
-        # --- 3. Get LLM Hidden States for the *Full* Generated Text Sequence ---
-        # Run a separate forward pass to reliably get hidden states for the complete sequence.
-        # This is simpler than parsing the complex output of generate with output_hidden_states=True.
-        full_text_embeds = self.llm.get_input_embeddings()(generated_text_ids)  # [B, S_full, D_llm]
-        full_text_attention_mask = (generated_text_ids != self.llm.config.pad_token_id).long()  # [B, S_full]
+        generated_text_ids = text_outputs.sequences  # [B, S_full]
+        thinker_token_embeds = [
+            token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_hidden_states = [
+            token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+        thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]

         # --- 4. Project Hidden States ---
-        projected_text_embeds = self.speech_token_projector(full_text_embeds)  # Shape [B, S_full, D_codec]
+        thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)  # [B, S_full, D_codec]
+        thinker_reply_part = self.speech_token_projector(thinker_reply_part)  # [B, S_full, D_codec]

         # --- 5. Generate Speech Tokens (Autoregressive Loop with Text Context) ---
         self.codec_lm.to(device)
         self.codec_lm_head.to(device)

-        # Initial input for the codec LM is the BOS token
-        current_speech_input_ids = torch.full(
-            (batch_size, 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=device
+        delay_step = 2
+        thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
+        talker_input_ids = torch.full(
+            (batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
         )
+        talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)  # [B, S_full, D_codec]
+        thinker_input_embeds = torch.cat(
+            [
+                thinker_prompt_part,
+                thinker_reply_part[:, :delay_step, :],
+            ],
+            dim=1,
+        )
+        talker_inputs_embeds += thinker_input_embeds
+        thinker_reply_part = thinker_reply_part[:, delay_step:, :]  # [B, S_full, D_codec]

         past_key_values = None
-        generated_speech_tokens_list = [[] for _ in range(batch_size)]
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
-
-        text_context_len = projected_text_embeds.shape[1]  # S_full
+        # generated_speech_tokens_list = [[] for _ in range(batch_size)]
+        # unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+        generated_speech_tokens_list = []
+        next_token_ids = None
+        # text_context_len = projected_text_embeds.shape[1]  # S_full
         for t in range(max_speech_new_tokens):
-            # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
-            current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids)  # [B, 1, D_codec]
-
-            # Add the projected text embedding corresponding to the current timestep `t`
-            if t < text_context_len:
-                # Text context from the full generated text sequence
-                current_text_context_embed = projected_text_embeds[:, t:t+1, :]  # [B, 1, D_codec]
-                inputs_embeds = current_speech_embeds + current_text_context_embed
-            else:
-                # No more text context to add
-                inputs_embeds = current_speech_embeds
-
-            # Ensure inputs_embeds has the correct dtype for the codec_lm
-            inputs_embeds = inputs_embeds.to(next(self.codec_lm.parameters()).dtype)
+            # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids)  # [B, 1, D_codec]
+            if next_token_ids is not None:
+                talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids)  # [B, 1, D_codec]
+                if thinker_reply_part.shape[1] > 0:
+                    talker_inputs_embeds += thinker_reply_part[:, :1, :]
+                    thinker_reply_part = thinker_reply_part[:, 1:, :]  # Remove the first token for next step
+            # # Add the projected text embedding corresponding to the current timestep `t`
+            # if t < text_context_len:
+            #     # Text context from the full generated text sequence
+            #     current_text_context_embed = projected_text_embeds[:, t:t+1, :]  # [B, 1, D_codec]
+            #     inputs_embeds = current_speech_embeds + current_text_context_embed
+            # else:
+            #     # No more text context to add
+            #     inputs_embeds = current_speech_embeds

             # Forward pass through codec LM for one step
             # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
             codec_outputs = self.codec_lm(
-                inputs_embeds=inputs_embeds,  # Combined embedding for this step
+                inputs_embeds=talker_inputs_embeds,  # Combined embedding for this step
                 past_key_values=past_key_values,
                 use_cache=True,
                 return_dict=True,
                 output_hidden_states=True,
                 # No attention mask needed here when using past_key_values and single token input
             )
+            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]  # [B, D_codec]
             # Get logits for the *last* token generated in this step
-            next_token_logits = self.codec_lm_head(codec_outputs.last_hidden_state[:, -1:, :])  # Use -1 index
-
-            # --- Process Output & Update State ---
-            # Greedy decoding (can be replaced with sampling based on codec_lm_kwargs)
-            # TODO: Implement sampling/beam search for codec LM if needed
-            next_token_ids = torch.argmax(next_token_logits, dim=-1)  # Greedy [B, 1]
-
-            # Mask out finished sequences
-            next_token_ids = next_token_ids * unfinished_sequences.unsqueeze(-1) + \
-                self.codec_lm.config.pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
-
-            # Store generated tokens for unfinished sequences
-            for i in range(batch_size):
-                if unfinished_sequences[i]:
-                    token_id = next_token_ids[i].item()
-                    if token_id == self.codec_lm.config.eos_token_id:
-                        unfinished_sequences[i] = 0  # Mark as finished
-                    elif token_id != self.codec_lm.config.pad_token_id:
-                        generated_speech_tokens_list[i].append(token_id)
-
-            # Prepare for next iteration
-            current_speech_input_ids = next_token_ids  # Use the newly generated token ID as input for next step
-            past_key_values = codec_outputs.past_key_values  # Update KV cache
-
-            # Stop if all sequences are finished
-            if unfinished_sequences.max() == 0:
+            next_token_logits = self.codec_lm_head(last_token_hidden_state)  # Use -1 index
+            # suppress tokens between 4096:len(vocab)-3
+            next_token_logits[:, 4096:-3] = -float("Inf")
+            next_token_ids = topk_sampling(
+                next_token_logits,
+            )
+            print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
+            if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
                 break

+            # current_speech_input_ids = next_token_ids  # Use the newly generated token ID as input for next step
+            past_key_values = codec_outputs.past_key_values  # Update KV cache
+            generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])

         # --- 6. Return Results ---
         return generated_text_ids, generated_speech_tokens_list
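Note: the rewritten loop above follows a thinker/talker pattern: the talker (codec LM) prompt covers the thinker prompt plus delay_step reply frames, and each later codec step consumes one more reply embedding. A shape-only sketch with dummy tensors (dimensions are illustrative, not the model's):

    import torch

    B, D_codec = 1, 16
    delay_step = 2
    prompt_len, reply_len = 5, 7            # hypothetical thinker prompt / reply lengths

    thinker_prompt_part = torch.randn(B, prompt_len, D_codec)
    thinker_reply_part = torch.randn(B, reply_len, D_codec)

    # talker prompt: BOS embeddings over prompt_len + delay_step positions,
    # summed with the thinker prompt and the first delay_step reply frames
    bos_embeds = torch.zeros(B, prompt_len + delay_step, D_codec)
    talker_inputs_embeds = bos_embeds + torch.cat(
        [thinker_prompt_part, thinker_reply_part[:, :delay_step, :]], dim=1
    )
    assert talker_inputs_embeds.shape == (B, prompt_len + delay_step, D_codec)

    # the remaining reply frames are consumed one per autoregressive codec step
    thinker_reply_part = thinker_reply_part[:, delay_step:, :]
    for _ in range(3):
        if thinker_reply_part.shape[1] > 0:
            step_embed = thinker_reply_part[:, :1, :]   # added to the new token's embedding
            thinker_reply_part = thinker_reply_part[:, 1:, :]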
@@ -626,3 +642,64 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
     )
     denominator = torch.sum(mask)
     return numerator.float() / denominator.float()
+
+
+def topk_sampling(
+    logits,
+    top_k=50,
+    top_p=0.95,
+    temperature=0.8,
+):
+    if temperature != 1.0:
+        logits = logits / temperature
+    # Top-p/top-k filtering
+    logits_filtered = top_k_top_p_filtering(
+        logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+    )
+    # Sample
+    probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
+    tokens = torch.multinomial(probs, num_samples=1)
+
+    return tokens
+
+
+# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
+def top_k_top_p_filtering(
+    logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(
+            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+        )
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+    return logits
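Note: a small usage sketch of the helpers above, including the vocabulary-range suppression applied in the decode loop (the vocabulary size and suppressed range mirror the values in this commit but are otherwise illustrative):

    import torch

    vocab_size = 8192                      # codec_vocab_size from the training config
    logits = torch.randn(1, vocab_size)    # [batch, vocab] logits from codec_lm_head

    # keep ids below 4096 plus the last three ids (special tokens, per the decode loop's comment)
    logits[:, 4096:-3] = -float("Inf")

    token = topk_sampling(logits, top_k=50, top_p=0.95, temperature=0.8)  # [1, 1]
    print(token)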
@@ -417,7 +417,7 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     else:
-        text_loss, acc, codec_loss, codec_acc = model.forward_with_speech_output(
+        text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
             fbank=feature,
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
@@ -442,6 +442,9 @@ def compute_loss(
         info["codec_acc"] = (
             codec_acc * info["frames"]
         )
+        info["codec_topk_acc"] = (
+            codec_topk_acc * info["frames"]
+        )
         info["codec_loss"] = codec_loss.detach().cpu().item()
         info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -743,6 +746,7 @@ def run(rank, world_size, args):
         #     torch_dtype=torch_dtype,
         # )
         codec_vocab_size = 8192
+        # TODO: modify above vocab size or supress_tokens when decoding
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
             hidden_size=1024,
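Note: the hunk above is truncated after hidden_size; the remaining Qwen2Config fields are not shown here. A hedged sketch of how such a config could be turned into a small codec LM (every value beyond vocab_size=8192 and hidden_size=1024 is an assumption for illustration, not taken from the commit):

    from transformers import Qwen2Config, Qwen2ForCausalLM

    codec_vocab_size = 8192
    config = Qwen2Config(
        vocab_size=codec_vocab_size,
        hidden_size=1024,
        # the fields below are assumptions, not values from this commit
        num_hidden_layers=12,
        num_attention_heads=16,
        num_key_value_heads=16,
        intermediate_size=2048,
        max_position_embeddings=4096,
    )
    codec_lm = Qwen2ForCausalLM(config)
    print(sum(p.numel() for p in codec_lm.parameters()) / 1e6, "M parameters")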