update README

This commit is contained in:
root 2025-04-29 08:49:12 +00:00
parent 448a4eeea7
commit 360f0aa397
12 changed files with 423 additions and 217 deletions

View File

@ -0,0 +1,89 @@
# Introduction

This recipe includes scripts for training speech-to-speech models.

# SPEECH2SPEECH

The following table lists the folders for the different tasks.

| Recipe | Speech Input | Speech Output | Comment |
|--------|--------------|---------------|---------|
| Qwen-Omni like | Continuous embeddings | CosyVoice1 50 Hz single-codebook tokens | Text-driven; a Thinker LLM produces the text tokens and a small Talker LLM produces the speech tokens |
### [Qwen-Omni like Speech2Speech Recipe](./qwen_omni)
A [Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni)-style model trained on the [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset. A minimal sketch of the Thinker/Talker handoff is shown below the figure.
<br>
<p align="center">
  <img src="assets/framework.jpg" width="800"/>
</p>
<br>
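At inference time the Thinker LLM first generates the text reply; for each generated position, its last-layer hidden state is concatenated with the corresponding token embedding, projected to the Talker's hidden size by `speech_token_projector`, and added to the Talker's input embeddings, after which the Talker autoregressively predicts CosyVoice codec tokens that a CosyVoice vocoder turns into a waveform. The snippet below is a minimal, self-contained sketch of that handoff with toy dimensions and untrained modules; `speech_token_projector` and `codec_lm_head` follow the names in this recipe's `model.py`, while the sizes and the stand-in Talker head are illustrative assumptions, not the real training code.

```python
# Minimal sketch (toy dimensions, untrained modules) of the Thinker/Talker handoff.
import torch
import torch.nn as nn

D_THINKER = 1536    # hidden size of the text (Thinker) LLM, e.g. Qwen2-1.5B-Instruct
D_TALKER = 1024     # hidden size of the small speech (Talker) LM (illustrative)
CODEC_VOCAB = 4096  # CosyVoice1 single-codebook size (50 Hz tokens)

# For each generated text position, the Thinker's last hidden state and the
# token embedding are concatenated and projected to the Talker's hidden size.
speech_token_projector = nn.Linear(D_THINKER + D_THINKER, D_TALKER)
codec_lm_head = nn.Linear(D_TALKER, CODEC_VOCAB)

T = 7  # pretend the Thinker already produced 7 text tokens for one utterance
thinker_hidden = torch.randn(1, T, D_THINKER)  # last-layer hidden states
thinker_embeds = torch.randn(1, T, D_THINKER)  # embeddings of the same tokens

talker_extra = speech_token_projector(
    torch.cat([thinker_hidden, thinker_embeds], dim=-1)
)  # (1, T, D_TALKER); added to the Talker's own input embeddings

# The real Talker is an autoregressive transformer LM; a single linear head
# stands in here just to show the tensor shapes end to end.
speech_logits = codec_lm_head(talker_extra)   # (1, T, CODEC_VOCAB)
speech_tokens = speech_logits.argmax(dim=-1)  # 50 Hz codec tokens for the CosyVoice vocoder
print(speech_tokens.shape)                    # torch.Size([1, 7])
```

In the actual model the Talker input is additionally delayed by one step (`delay_step = 1`) so that the speech tokens lag the text stream; see `decode_with_speech_output` in `model.py`.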
The command for training is:
```bash
pip install -r whisper_llm_zh/requirements.txt
pip install "huggingface_hub[cli]"
mkdir -p models/whisper models/qwen
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
# First, we only train the projector and freeze other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora False --unfreeze-llm False
# Then we jointly train the projector and LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True \
  --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
```
The command for decoding is:
```bash
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s $(pwd)/models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
```

Binary file not shown (new image, 331 KiB)

View File

@ -17,6 +17,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank \
--huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
--audio-key question_audio --text-key answer \
--prefix ultrachat
"""
import argparse
import logging
@ -126,7 +136,7 @@ def compute_fbank(args):
num_digits = 5
for i in range(num_shards):
shard = dataset.shard(num_shards, i)
shard = shard.take(10) # for testing
# shard = shard.take(10) # for testing
logging.info(
f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
)
@ -159,8 +169,6 @@ def compute_fbank(args):
logging.info(f"Saving to {cuts_path}")
# see https://github.com/lhotse-speech/lhotse/issues/1125
cut_set.drop_recordings().to_file(cuts_path)
if i > 1:
break
def main():

View File

@ -1,11 +1,14 @@
from typing import List, Tuple # Added for type hints
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
from typing import List, Tuple # Added for type hints
from torchmetrics.classification import MulticlassAccuracy
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import logging
class EncoderProjector(nn.Module):
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
@ -69,7 +72,8 @@ class SPEECH_LLM(nn.Module):
self.codec_lm = codec_lm
if self.codec_lm:
self.speech_token_projector = nn.Linear(
self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
self.llm.config.hidden_size + self.llm.config.hidden_size,
self.codec_lm.config.hidden_size,
)
self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
@ -89,6 +93,7 @@ class SPEECH_LLM(nn.Module):
multidim_average="global",
ignore_index=IGNORE_TOKEN_ID,
)
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
@ -274,68 +279,115 @@ class SPEECH_LLM(nn.Module):
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
input_seq_len = attention_mask.sum(dim=1) # shape, B
text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
input_seq_len = attention_mask.sum(dim=1) # shape, B
(
text_label_start_index_list,
text_input_start_index_list,
input_question_len_list,
) = ([], [], [])
for i in range(labels.shape[0]):
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
input_embeds_start_index = input_embeds_valid_index[0]
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
text_labels_start_index = text_labels_valid_index[0]
assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
assert (
input_seq_len[i]
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert (
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index
assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
assert (
input_question_len
+ text_labels_valid_index[-1]
- text_labels_start_index
+ 1
== input_seq_len[i]
)
text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
)
text_loss = model_outputs.loss
delay_step = 1
# prepare codec lm inputs
audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
audio_codes_lens = [
len(x) + input_question_len_list[i] + delay_step + 1
for i, x in enumerate(speech_codec_ids)
]
max_len_speech_codec = max(audio_codes_lens)
if self.codec_lm_padding_side == "right":
audio_codes = [
[self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
[self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
[self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
elif self.codec_lm_padding_side == "left":
audio_codes = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
[self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids)
]
audio_codes = torch.tensor(
audio_codes,
dtype=torch.int64,
device=input_ids.device
audio_codes, dtype=torch.int64, device=input_ids.device
)
audio_labels = torch.tensor(
audio_labels,
dtype=torch.int64,
device=input_ids.device
audio_labels, dtype=torch.int64, device=input_ids.device
)
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
for i in range(len(text_label_start_index_list)):
text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
text_last_hidden = model_outputs.hidden_states[-1][
i,
text_input_start_index_list[i] : text_input_start_index_list[i]
+ input_seq_len[i]
- 1,
]
text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos
text_embed = inputs_embeds[
i,
text_input_start_index_list[i]
+ 1 : text_input_start_index_list[i]
+ input_seq_len[i],
] # exclude bos
text_embeds_list.append(text_embed)
text_input_embeds = torch.cat(
@ -344,22 +396,34 @@ class SPEECH_LLM(nn.Module):
text_embed,
],
dim=-1,
)# shape, T, D1 + D2
text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec
) # shape, T, D1 + D2
text_input_embeds = self.speech_token_projector(
text_input_embeds
) # shape, T, D_codec
text_input_embeds_list.append(text_input_embeds)
for i in range(audio_embeddings.shape[0]):
text_input_embeds = text_input_embeds_list[i]
if self.codec_lm_padding_side == "right":
audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
elif self.codec_lm_padding_side == "left":
start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
start_idx = torch.where(
audio_codes[i] == self.codec_lm.config.mask_token_id
)[0][0]
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
assert (
start_idx == start_idx_re_compute
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds
text_input_embeds = text_input_embeds[
: audio_embeddings.shape[1] - start_idx
]
logging.warning(
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
)
audio_embeddings[
i, start_idx : start_idx + text_input_embeds.shape[0]
] += text_input_embeds
speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask,
@ -369,8 +433,10 @@ class SPEECH_LLM(nn.Module):
)
last_hidden_state = speech_outputs.hidden_states[-1].clone()
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size)
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
audio_logits = audio_logits.contiguous().view(
-1, self.codec_lm.config.vocab_size
)
audio_labels = audio_labels.contiguous().view(-1)
audio_labels = audio_labels.masked_fill(
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
@ -378,7 +444,6 @@ class SPEECH_LLM(nn.Module):
codec_loss = self.loss_fct(audio_logits, audio_labels)
audio_preds = torch.argmax(audio_logits, -1)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
@ -392,12 +457,11 @@ class SPEECH_LLM(nn.Module):
ignore_label=IGNORE_TOKEN_ID,
)
audio_topk_acc = self.audio_accuracy_metric(
audio_logits.detach(),
audio_labels.detach()).item()
audio_logits.detach(), audio_labels.detach()
).item()
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
def decode(
self,
fbank: torch.Tensor = None,
@ -453,12 +517,12 @@ class SPEECH_LLM(nn.Module):
def decode_with_speech_output(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask
input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024,
max_speech_new_tokens: int = 1024, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
max_speech_new_tokens: int = 1024, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]:
"""
Generates text and corresponding speech tokens using the revised logic.
@ -479,16 +543,22 @@ class SPEECH_LLM(nn.Module):
the generated speech codec tokens for a batch item.
"""
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
if (
not self.codec_lm
or not self.speech_token_projector
or not self.codec_lm_head
):
raise ValueError(
"codec_lm and associated layers must be initialized to generate speech output."
)
device = next(self.parameters()).device # Use model's device
device = next(self.parameters()).device # Use model's device
batch_size = fbank.shape[0]
# --- 1. Prepare Prompt Embeddings ---
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
@ -511,12 +581,12 @@ class SPEECH_LLM(nn.Module):
"eos_token_id": self.llm.config.eos_token_id,
"pad_token_id": self.llm.config.pad_token_id,
"num_beams": 1,
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"top_p": 0.5,
"top_k": 20,
"repetition_penalty": 1.1,
"temperature": 0.7,
**(llm_kwargs or {}) # User-provided kwargs override defaults
**(llm_kwargs or {}), # User-provided kwargs override defaults
}
text_outputs = self.llm.generate(
@ -525,17 +595,22 @@ class SPEECH_LLM(nn.Module):
max_new_tokens=max_text_new_tokens,
return_dict_in_generate=True,
output_hidden_states=True,
**final_llm_kwargs
**final_llm_kwargs,
)
delay_step = 1
generated_text_ids = text_outputs.sequences # [B, S_full]
generated_text_ids = text_outputs.sequences # [B, S_full]
eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D
assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
eos_token_embedding = self.llm.get_input_embeddings()(
torch.tensor([[eos_token_id]], device=device)
) # 1,D
assert (
generated_text_ids[0, -1] == eos_token_id
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
token_hidden_states[0].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
@ -544,19 +619,27 @@ class SPEECH_LLM(nn.Module):
dim=1,
)
thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
thinker_token_embeds = (
[first_thinker_token_embed]
+ thinker_token_embeds_org[2:]
+ [eos_token_embedding]
)
thinker_hidden_states = [
token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
token_hidden_states[-1].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_reply_part = [torch.cat(
[
thinker_hidden_state,
thinker_token_embed,
],
dim=-1,
)
for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
thinker_reply_part = [
torch.cat(
[
thinker_hidden_state,
thinker_token_embed,
],
dim=-1,
)
for thinker_hidden_state, thinker_token_embed in zip(
thinker_hidden_states[1:], thinker_token_embeds[1:]
)
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
@ -568,26 +651,35 @@ class SPEECH_LLM(nn.Module):
dim=-1,
)
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec]
thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec]
thinker_prompt_part = self.speech_token_projector(
thinker_prompt_part
) # [B, S_full, D_codec]
thinker_reply_part = self.speech_token_projector(
thinker_reply_part
) # [B, S_full, D_codec]
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
(batch_size, thinker_prompt_part_seq_len + delay_step + 1),
self.codec_lm.config.mask_token_id,
dtype=torch.long,
device=self.llm.device,
)
talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
talker_input_ids
) # [B, S_full, D_codec]
thinker_input_embeds = torch.cat(
[
thinker_prompt_part,
thinker_reply_part[:, :delay_step + 1, :],
thinker_reply_part[:, : delay_step + 1, :],
],
dim=1,
)
talker_inputs_embeds += thinker_input_embeds
thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec]
thinker_reply_part = thinker_reply_part[
:, delay_step + 1 :, :
] # [B, S_full, D_codec]
past_key_values = None
# generated_speech_tokens_list = [[] for _ in range(batch_size)]
@ -599,10 +691,14 @@ class SPEECH_LLM(nn.Module):
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
if t > 0:
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
next_token_ids
) # [B, 1, D_codec]
if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :]
thinker_reply_part = thinker_reply_part[:, 1:, :] # Remove the first token for next step
thinker_reply_part = thinker_reply_part[
:, 1:, :
] # Remove the first token for next step
# # Add the projected text embedding corresponding to the current timestep `t`
# if t < text_context_len:
# # Text context from the full generated text sequence
@ -611,20 +707,24 @@ class SPEECH_LLM(nn.Module):
# else:
# # No more text context to add
# inputs_embeds = current_speech_embeds
# Forward pass through codec LM for one step
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
codec_outputs = self.codec_lm(
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
output_hidden_states=True,
# No attention mask needed here when using past_key_values and single token input
)
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] #TODO: check shape here
last_token_hidden_state = codec_outputs.hidden_states[-1][
:, -1, :
] # [B, D_codec] #TODO: check shape here
# Get logits for the *last* token generated in this step
next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
next_token_logits = self.codec_lm_head(
last_token_hidden_state
) # Use -1 index
# suppress tokens between 4096:len(vocab)-3
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where should we suppress tokens?
next_token_ids = topk_sampling(
@ -634,11 +734,14 @@ class SPEECH_LLM(nn.Module):
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
past_key_values = codec_outputs.past_key_values # Update KV cache
generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])
past_key_values = codec_outputs.past_key_values # Update KV cache
generated_speech_tokens_list.append(
next_token_ids.squeeze(1).cpu().tolist()[0]
)
# --- 6. Return Results ---
return generated_text_ids, generated_speech_tokens_list
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
@ -717,4 +820,4 @@ def top_k_top_p_filtering(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
return logits

View File

@ -1,13 +1,12 @@
from typing import Callable, Dict, List, Union
import torch
from torch.utils.data.dataloader import DataLoader, default_collate
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix
from torch.utils.data.dataloader import DataLoader, default_collate
class K2SpeechRecognitionDataset(torch.utils.data.Dataset):

View File

@ -1,26 +1,25 @@
# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
import io
import numpy as np
import gradio as gr
import soundfile as sf
import gradio.processing_utils as processing_utils
import tempfile
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from gradio_client import utils as client_utils
from argparse import ArgumentParser
import whisper
import torch
from peft import LoraConfig, get_peft_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from model import SPEECH_LLM, EncoderProjector
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
import sherpa_onnx
from cosyvoice.cli.cosyvoice import CosyVoice
import sys
sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS')
from argparse import ArgumentParser
import gradio as gr
import gradio.processing_utils as processing_utils
import numpy as np
import sherpa_onnx
import soundfile as sf
import torch
import whisper
from cosyvoice.cli.cosyvoice import CosyVoice
from gradio_client import utils as client_utils
from model import SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def get_model(params, device="cuda"):
@ -88,7 +87,7 @@ def get_model(params, device="cuda"):
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch.float16
torch_dtype=torch.float16,
)
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
@ -102,12 +101,10 @@ def get_model(params, device="cuda"):
llm,
encoder_projector,
codec_lm,
codec_lm_padding_side= "left" if params.use_flash_attn else "right",
codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
checkpoint = torch.load(
f"{params.checkpoint_path}", map_location="cpu"
)
checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
model.to(device)
@ -122,27 +119,37 @@ def audio_decode_cosyvoice(audio_tokens, codec_decoder):
Args:
audio_tokens (list): List of audio tokens to be processed.
codec_decoder: Codec decoder for generating audio.
Returns:
torch.Tensor: Generated audio waveform.
"""
flow_embedding = codec_decoder.frontend.spk2info['中文女']['embedding']
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
prompt_speech_feat = torch.zeros(1, 0, 80)
tts_mel, _ = codec_decoder.model.flow.inference(token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
embedding=flow_embedding.to(codec_decoder.model.device),
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
prompt_token_len=torch.tensor(
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
prompt_feat_len=torch.tensor(
[prompt_speech_feat.shape[1]], dtype=torch.int32
).to(codec_decoder.model.device),
embedding=flow_embedding.to(codec_decoder.model.device),
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
)
audio_hat, _ = codec_decoder.model.hift.inference(speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0))
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def preprocess(
messages,
tokenizer,
@ -178,28 +185,14 @@ def preprocess(
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def format_history(history: list):
messages = []
for item in history:
if isinstance(item["content"], str):
messages.append({"role": item['role'], "content": item['content']})
# elif item["role"] == "user" and (isinstance(item["content"], list) or
# isinstance(item["content"], tuple)):
# file_path = item["content"][0]
# # TODO: check if the file_path's transcript is already in the history
# mime_type = client_utils.get_mimetype(file_path)
# if mime_type.startswith("audio"):
# messages.append({
# "role":
# item['role'],
# "content": item["content"][1] # append audio transcript here
# })
print('predict history: ', messages)
# messages = messages[-2:] # TODO: WAR: add history later
messages.append({"role": item["role"], "content": item["content"]})
return messages
def decode(
@ -217,9 +210,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
dtype = torch.float32
device = model.llm.device
feature = feature.to(device, dtype=dtype)#.transpose(1, 2)
# assert feature.shape[2] == 80
feature = feature.to(device, dtype=dtype)
input_ids, attention_mask = preprocess([messages], tokenizer)
generated_ids, audio_tokens = model.decode_with_speech_output(
@ -227,26 +219,21 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# print('hyps: ', hyps, 23333333333333333333333333)
yield {"type": "text", "data": hyps[0]}
# yield {"type": "text", "data": hyps}
audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
audio = audio_hat.squeeze(0).cpu().numpy()
# sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050)
audio = audio_hat.squeeze(0).cpu().numpy()
audio = np.array(audio * 32767).astype(np.int16)
# yield {"type": "audio", "data": (22050, audio)}
# with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
# sf.write(tmpfile.name, audio, 22050, format="WAV")
# audio_path = tmpfile.name
wav_io = io.BytesIO()
sf.write(wav_io, audio, samplerate=22050, format="WAV")
wav_io.seek(0)
wav_bytes = wav_io.getvalue()
audio_path = processing_utils.save_bytes_to_cache(
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
)
yield {"type": "audio", "data": audio_path}
@ -259,25 +246,27 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
gr.update(visible=True), # stop_btn
)
print(2333, history, audio)
history.append({"role": "user", "content": (audio,)})
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
history.append({"role": "user", "content": (audio,)})
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
history.append({"role": "assistant", "content": ""})
formatted_history = format_history(history=history) # only keep string text format
formatted_history = format_history(
history=history
) # only keep string text format
assert audio is not None
audio_transcript = get_transcript(
audio,
asr_model,
)
print('audio_transcript: ', audio_transcript)
history[-2]["content"] = audio_transcript
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
fbank = fbank.unsqueeze(0)
assert fbank.ndim == 3
# history.append({"role": "assistant", "content": ""})
for chunk in decode(model, token2wav_model, tokenizer, fbank, formatted_history):
for chunk in decode(
model, token2wav_model, tokenizer, fbank, formatted_history
):
if chunk["type"] == "text":
history[-1]["content"] = chunk["data"]
yield (
@ -287,10 +276,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
gr.update(visible=True), # stop_btn
)
if chunk["type"] == "audio":
history.append({
"role": "assistant",
"content": gr.Audio(chunk["data"])
})
history.append(
{"role": "assistant", "content": gr.Audio(chunk["data"])}
)
# Final yield
yield (
@ -304,8 +292,7 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
with gr.Tab("Online"):
with gr.Row():
with gr.Column(scale=1):
microphone = gr.Audio(sources=['microphone'],
type="filepath")
microphone = gr.Audio(sources=["microphone"], type="filepath")
submit_btn = gr.Button("Submit", variant="primary")
stop_btn = gr.Button("Stop", visible=False)
clear_btn = gr.Button("Clear History")
@ -315,64 +302,80 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def clear_history():
return [], gr.update(value=None)
submit_event = submit_btn.click(fn=media_predict,
inputs=[
microphone,
media_chatbot,
],
outputs=[
microphone,
media_chatbot, submit_btn,
stop_btn
])
submit_event = submit_btn.click(
fn=media_predict,
inputs=[
microphone,
media_chatbot,
],
outputs=[microphone, media_chatbot, submit_btn, stop_btn],
)
stop_btn.click(
fn=lambda:
(gr.update(visible=True), gr.update(visible=False)),
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=[submit_event],
queue=False)
clear_btn.click(fn=clear_history,
inputs=None,
outputs=[media_chatbot, microphone])
queue=False,
)
clear_btn.click(
fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
)
demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
ssr_mode=False,
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,)
demo.queue(default_concurrency_limit=100, max_size=100).launch(
max_threads=100,
ssr_mode=False,
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
def _get_args():
parser = ArgumentParser()
parser.add_argument('--checkpoint-path',
type=str,
default=None,
help='Checkpoint name or path, default to %(default)r')
parser.add_argument('--token2wav-path',
type=str,
default=None,
help='Token2Wav path, default to %(default)r')
parser.add_argument('--asr-model-dir',
type=str,
default=None,
help='ASR model dir, default to %(default)r')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.')
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
parser.add_argument(
"--checkpoint-path",
type=str,
default=None,
help="Checkpoint name or path, default to %(default)r",
)
parser.add_argument(
"--token2wav-path",
type=str,
default=None,
help="Token2Wav path, default to %(default)r",
)
parser.add_argument(
"--asr-model-dir",
type=str,
default=None,
help="ASR model dir, default to %(default)r",
)
parser.add_argument(
"--flash-attn2",
action="store_true",
default=False,
help="Enable flash_attention_2 when loading the model.",
)
parser.add_argument(
"--share",
action="store_true",
default=False,
help="Create a publicly shareable link for the interface.",
)
parser.add_argument(
"--inbrowser",
action="store_true",
default=False,
help="Automatically launch the interface in a new tab on the default browser.",
)
parser.add_argument(
"--server-port", type=int, default=8001, help="Demo server port."
)
parser.add_argument(
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
)
add_model_arguments(parser)
args = parser.parse_args()
return args
@ -401,6 +404,7 @@ def read_wave(wave_filename: str):
return samples_float32, sample_rate
def get_transcript(audio_path, recognizer):
samples, sample_rate = read_wave(audio_path)
s = recognizer.create_stream()
@ -408,10 +412,13 @@ def get_transcript(audio_path, recognizer):
recognizer.decode_streams([s])
return s.result.text
if __name__ == "__main__":
args = _get_args()
model, tokenizer = get_model(args)
token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
token2wav = CosyVoice(
args.token2wav_path, load_jit=False, load_trt=False, fp16=False
)
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
@ -423,4 +430,4 @@ if __name__ == "__main__":
debug=False,
)
_launch_demo(args, model, tokenizer, token2wav, asr_model)
_launch_demo(args, model, tokenizer, token2wav, asr_model)