update README

commit 360f0aa397 (parent 448a4eeea7)

egs/speech_llm/SPEECH2SPEECH/README.md (new file, 89 lines)

@@ -0,0 +1,89 @@
# Introduction

This recipe includes scripts for training speech2speech models.

# SPEECH2SPEECH

The following table lists the folders for different tasks.

|Recipe | Speech Input | Speech Output | Comment|
|--------------|--------------|---------------|--------|
|Qwen-omni like| Continuous embeddings | CosyVoice1 50Hz single-codebook tokens | Text-driven; a Thinker LLM generates text tokens and a small Talker LLM generates speech tokens |

### [Qwen-omni like Speech2Speech Recipe](./qwen_omni)

A [Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni) style model trained on the [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.

<br>
<p align="center">
  <img src="assets/framework.jpg" width="800"/>
</p>
<br>
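As the framework figure suggests, the Thinker LLM first generates the text response; its last hidden states, concatenated with the corresponding text embeddings and projected into the Talker's hidden size, then condition a small Talker LM that autoregressively predicts CosyVoice codec tokens. The following is a simplified sketch of that loop, assuming module names from this recipe's model code (`speech_token_projector`, `codec_lm`, `codec_lm_head`); it omits details such as the delay step, padding and sampling.

```python
import torch


def talker_generate(thinker_hidden, thinker_embeds, speech_token_projector,
                    codec_lm, codec_lm_head, bos_id, eos_id, max_new_tokens=1024):
    # thinker_hidden / thinker_embeds: [1, T, D_llm] tensors from the Thinker LLM
    # speech_token_projector: nn.Linear(2 * D_llm, D_codec)
    # codec_lm: a small causal LM over the speech-codec vocabulary
    cond = speech_token_projector(
        torch.cat([thinker_hidden, thinker_embeds], dim=-1)
    )  # [1, T, D_codec]

    bos = torch.tensor([[bos_id]], device=cond.device)
    inputs_embeds = codec_lm.get_input_embeddings()(bos) + cond[:, :1]
    past, speech_tokens = None, []
    for t in range(max_new_tokens):
        out = codec_lm(
            inputs_embeds=inputs_embeds,
            past_key_values=past,
            use_cache=True,
            output_hidden_states=True,
        )
        logits = codec_lm_head(out.hidden_states[-1][:, -1])  # [1, vocab]
        next_id = logits.argmax(dim=-1, keepdim=True)  # greedy here for brevity
        if next_id.item() == eos_id:
            break
        speech_tokens.append(next_id.item())
        past = out.past_key_values
        inputs_embeds = codec_lm.get_input_embeddings()(next_id)
        if t + 1 < cond.shape[1]:  # keep adding Thinker context while it lasts
            inputs_embeds = inputs_embeds + cond[:, t + 1 : t + 2]
    return speech_tokens  # CosyVoice tokens, later vocoded into a waveform
```

In the recipe itself the Talker input is additionally offset by a small delay step, and speech tokens are sampled from the logits rather than taken greedily.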
The command for training is:

```bash
pip install -r whisper_llm_zh/requirements.txt
pip install 'huggingface_hub[cli]'

mkdir -p models/whisper models/qwen

# For the aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

# First, we train only the projector and freeze the other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora False --unfreeze-llm False

# Then, we jointly train the projector and the LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True \
  --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
```
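The first stage trains only the encoder projector, which maps Whisper encoder frames into the Qwen2 embedding space while both the speech encoder and the LLM stay frozen. Below is a minimal sketch of such a projector; the module name and dimensions are illustrative (Whisper large-v2 uses 1280-dim encoder states and Qwen2-1.5B-Instruct uses a 1536-dim hidden size), and the recipe's actual `EncoderProjector` may differ, for example in how it downsamples frames.

```python
import torch.nn as nn


class SimpleEncoderProjector(nn.Module):
    """Hypothetical projector: Whisper encoder dim -> LLM hidden dim."""

    def __init__(self, encoder_dim=1280, llm_dim=1536, downsample_rate=4):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):  # x: [B, T, encoder_dim]
        B, T, D = x.shape
        T = T - T % self.downsample_rate  # drop trailing frames
        # stack k consecutive frames to shorten the sequence k times
        x = x[:, :T].reshape(B, T // self.downsample_rate, D * self.downsample_rate)
        return self.linear2(self.relu(self.linear1(x)))  # [B, T/k, llm_dim]
```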

The command for decoding is:

```bash
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

# For the aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s $(pwd)/models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
```
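The command above produces text output. For the speech2speech path, the Talker's CosyVoice tokens are additionally vocoded into a waveform with the CosyVoice flow-matching decoder and HiFT vocoder. Below is a condensed sketch adapted from the Gradio demo code in this recipe; the local CosyVoice checkout and the default `中文女` speaker embedding are assumptions carried over from that demo.

```python
import torch
from cosyvoice.cli.cosyvoice import CosyVoice  # assumes a local CosyVoice checkout


def tokens_to_waveform(audio_tokens, codec_decoder: CosyVoice):
    """Vocode CosyVoice1 single-codebook tokens into audio (sketch from the demo)."""
    device = codec_decoder.model.device
    # drop special tokens outside the 4096-entry codec vocabulary, as in the demo
    tokens = torch.tensor(
        [t for t in audio_tokens if t < 4096], dtype=torch.int32
    ).unsqueeze(0)  # [1, T]
    embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]  # default speaker
    mel, _ = codec_decoder.model.flow.inference(
        token=tokens.to(device),
        token_len=torch.tensor([tokens.shape[1]], dtype=torch.int32).to(device),
        prompt_token=torch.zeros(1, 0, dtype=torch.int32).to(device),
        prompt_token_len=torch.tensor([0], dtype=torch.int32).to(device),
        prompt_feat=torch.zeros(1, 0, 80).to(device),
        prompt_feat_len=torch.tensor([0], dtype=torch.int32).to(device),
        embedding=embedding.to(device),
        flow_cache=torch.zeros(1, 80, 0, 2).to(device),
    )
    audio, _ = codec_decoder.model.hift.inference(
        speech_feat=mel, cache_source=torch.zeros(1, 1, 0)
    )
    return audio.squeeze(0).cpu().numpy()  # the demo writes this at 22050 Hz
```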
egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg (new binary file, 331 KiB; not shown)
local/compute_whisper_fbank.py

@@ -17,6 +17,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Usage:
+python3 local/compute_whisper_fbank.py \
+  --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+  --out-dir data/fbank \
+  --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
+  --audio-key question_audio --text-key answer \
+  --prefix ultrachat
+"""
+

 import argparse
 import logging
@@ -126,7 +136,7 @@ def compute_fbank(args):
     num_digits = 5
     for i in range(num_shards):
         shard = dataset.shard(num_shards, i)
-        shard = shard.take(10)  # for testing
+        # shard = shard.take(10)  # for testing
         logging.info(
             f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
         )
@@ -159,8 +169,6 @@ def compute_fbank(args):
         logging.info(f"Saving to {cuts_path}")
         # see https://github.com/lhotse-speech/lhotse/issues/1125
         cut_set.drop_recordings().to_file(cuts_path)
-        if i > 1:
-            break


 def main():
@@ -1,11 +1,14 @@
+from typing import List, Tuple  # Added for type hints
+
 import torch
 from torch import nn
-from transformers.trainer_pt_utils import LabelSmoother
-from typing import List, Tuple  # Added for type hints
 from torchmetrics.classification import MulticlassAccuracy
+from transformers.trainer_pt_utils import LabelSmoother

 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 import logging


 class EncoderProjector(nn.Module):
     """
     The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
@ -69,7 +72,8 @@ class SPEECH_LLM(nn.Module):
|
|||||||
self.codec_lm = codec_lm
|
self.codec_lm = codec_lm
|
||||||
if self.codec_lm:
|
if self.codec_lm:
|
||||||
self.speech_token_projector = nn.Linear(
|
self.speech_token_projector = nn.Linear(
|
||||||
self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
|
self.llm.config.hidden_size + self.llm.config.hidden_size,
|
||||||
|
self.codec_lm.config.hidden_size,
|
||||||
)
|
)
|
||||||
self.codec_lm_head = nn.Linear(
|
self.codec_lm_head = nn.Linear(
|
||||||
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
|
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
|
||||||
@ -89,6 +93,7 @@ class SPEECH_LLM(nn.Module):
|
|||||||
multidim_average="global",
|
multidim_average="global",
|
||||||
ignore_index=IGNORE_TOKEN_ID,
|
ignore_index=IGNORE_TOKEN_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _merge_input_ids_with_speech_features(
|
def _merge_input_ids_with_speech_features(
|
||||||
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
|
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
|
||||||
):
|
):
|
||||||
@ -274,68 +279,115 @@ class SPEECH_LLM(nn.Module):
|
|||||||
) = self._merge_input_ids_with_speech_features(
|
) = self._merge_input_ids_with_speech_features(
|
||||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||||
)
|
)
|
||||||
input_seq_len = attention_mask.sum(dim=1) # shape, B
|
input_seq_len = attention_mask.sum(dim=1) # shape, B
|
||||||
text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
|
(
|
||||||
|
text_label_start_index_list,
|
||||||
|
text_input_start_index_list,
|
||||||
|
input_question_len_list,
|
||||||
|
) = ([], [], [])
|
||||||
for i in range(labels.shape[0]):
|
for i in range(labels.shape[0]):
|
||||||
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
|
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
|
||||||
input_embeds_start_index = input_embeds_valid_index[0]
|
input_embeds_start_index = input_embeds_valid_index[0]
|
||||||
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
|
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
|
||||||
text_labels_start_index = text_labels_valid_index[0]
|
text_labels_start_index = text_labels_valid_index[0]
|
||||||
|
|
||||||
assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
|
assert (
|
||||||
assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
|
input_seq_len[i]
|
||||||
|
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
|
||||||
|
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
|
||||||
|
assert (
|
||||||
|
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
|
||||||
|
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
|
||||||
input_question_len = text_labels_start_index - input_embeds_start_index
|
input_question_len = text_labels_start_index - input_embeds_start_index
|
||||||
assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
|
assert (
|
||||||
|
input_question_len
|
||||||
|
+ text_labels_valid_index[-1]
|
||||||
|
- text_labels_start_index
|
||||||
|
+ 1
|
||||||
|
== input_seq_len[i]
|
||||||
|
)
|
||||||
text_label_start_index_list.append(text_labels_start_index)
|
text_label_start_index_list.append(text_labels_start_index)
|
||||||
text_input_start_index_list.append(input_embeds_start_index)
|
text_input_start_index_list.append(input_embeds_start_index)
|
||||||
input_question_len_list.append(input_question_len)
|
input_question_len_list.append(input_question_len)
|
||||||
|
|
||||||
model_outputs = self.llm(
|
model_outputs = self.llm(
|
||||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
labels=labels,
|
||||||
|
output_hidden_states=True,
|
||||||
)
|
)
|
||||||
text_loss = model_outputs.loss
|
text_loss = model_outputs.loss
|
||||||
delay_step = 1
|
delay_step = 1
|
||||||
# prepare codec lm inputs
|
# prepare codec lm inputs
|
||||||
audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
|
audio_codes_lens = [
|
||||||
|
len(x) + input_question_len_list[i] + delay_step + 1
|
||||||
|
for i, x in enumerate(speech_codec_ids)
|
||||||
|
]
|
||||||
max_len_speech_codec = max(audio_codes_lens)
|
max_len_speech_codec = max(audio_codes_lens)
|
||||||
|
|
||||||
if self.codec_lm_padding_side == "right":
|
if self.codec_lm_padding_side == "right":
|
||||||
audio_codes = [
|
audio_codes = [
|
||||||
[self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
|
[self.codec_lm.config.mask_token_id]
|
||||||
|
* (input_question_len_list[i] + delay_step)
|
||||||
|
+ [self.codec_lm.config.bos_token_id]
|
||||||
|
+ x
|
||||||
|
+ [self.codec_lm.config.pad_token_id]
|
||||||
|
* (max_len_speech_codec - audio_codes_lens[i])
|
||||||
for i, x in enumerate(speech_codec_ids)
|
for i, x in enumerate(speech_codec_ids)
|
||||||
]
|
]
|
||||||
audio_labels = [
|
audio_labels = [
|
||||||
[self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
|
[self.codec_lm.config.pad_token_id]
|
||||||
|
* (input_question_len_list[i] + delay_step)
|
||||||
|
+ x
|
||||||
|
+ [self.codec_lm.config.eos_token_id]
|
||||||
|
+ [self.codec_lm.config.pad_token_id]
|
||||||
|
* (max_len_speech_codec - audio_codes_lens[i])
|
||||||
for i, x in enumerate(speech_codec_ids)
|
for i, x in enumerate(speech_codec_ids)
|
||||||
]
|
]
|
||||||
elif self.codec_lm_padding_side == "left":
|
elif self.codec_lm_padding_side == "left":
|
||||||
audio_codes = [
|
audio_codes = [
|
||||||
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
|
[self.codec_lm.config.pad_token_id]
|
||||||
|
* (max_len_speech_codec - audio_codes_lens[i])
|
||||||
|
+ [self.codec_lm.config.mask_token_id]
|
||||||
|
* (input_question_len_list[i] + delay_step)
|
||||||
|
+ [self.codec_lm.config.bos_token_id]
|
||||||
|
+ x
|
||||||
for i, x in enumerate(speech_codec_ids)
|
for i, x in enumerate(speech_codec_ids)
|
||||||
]
|
]
|
||||||
audio_labels = [
|
audio_labels = [
|
||||||
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
|
[self.codec_lm.config.pad_token_id]
|
||||||
|
* (max_len_speech_codec - audio_codes_lens[i])
|
||||||
|
+ [self.codec_lm.config.pad_token_id]
|
||||||
|
* (input_question_len_list[i] + delay_step)
|
||||||
|
+ x
|
||||||
|
+ [self.codec_lm.config.eos_token_id]
|
||||||
for i, x in enumerate(speech_codec_ids)
|
for i, x in enumerate(speech_codec_ids)
|
||||||
]
|
]
|
||||||
audio_codes = torch.tensor(
|
audio_codes = torch.tensor(
|
||||||
audio_codes,
|
audio_codes, dtype=torch.int64, device=input_ids.device
|
||||||
dtype=torch.int64,
|
|
||||||
device=input_ids.device
|
|
||||||
)
|
)
|
||||||
audio_labels = torch.tensor(
|
audio_labels = torch.tensor(
|
||||||
audio_labels,
|
audio_labels, dtype=torch.int64, device=input_ids.device
|
||||||
dtype=torch.int64,
|
|
||||||
device=input_ids.device
|
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
|
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
|
||||||
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
|
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
|
||||||
|
|
||||||
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
|
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
|
||||||
for i in range(len(text_label_start_index_list)):
|
for i in range(len(text_label_start_index_list)):
|
||||||
text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
|
text_last_hidden = model_outputs.hidden_states[-1][
|
||||||
|
i,
|
||||||
|
text_input_start_index_list[i] : text_input_start_index_list[i]
|
||||||
|
+ input_seq_len[i]
|
||||||
|
- 1,
|
||||||
|
]
|
||||||
text_last_hidden_lists.append(text_last_hidden)
|
text_last_hidden_lists.append(text_last_hidden)
|
||||||
text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos
|
text_embed = inputs_embeds[
|
||||||
|
i,
|
||||||
|
text_input_start_index_list[i]
|
||||||
|
+ 1 : text_input_start_index_list[i]
|
||||||
|
+ input_seq_len[i],
|
||||||
|
] # exclude bos
|
||||||
text_embeds_list.append(text_embed)
|
text_embeds_list.append(text_embed)
|
||||||
|
|
||||||
text_input_embeds = torch.cat(
|
text_input_embeds = torch.cat(
|
||||||
@ -344,22 +396,34 @@ class SPEECH_LLM(nn.Module):
|
|||||||
text_embed,
|
text_embed,
|
||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)# shape, T, D1 + D2
|
) # shape, T, D1 + D2
|
||||||
text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec
|
text_input_embeds = self.speech_token_projector(
|
||||||
|
text_input_embeds
|
||||||
|
) # shape, T, D_codec
|
||||||
text_input_embeds_list.append(text_input_embeds)
|
text_input_embeds_list.append(text_input_embeds)
|
||||||
|
|
||||||
for i in range(audio_embeddings.shape[0]):
|
for i in range(audio_embeddings.shape[0]):
|
||||||
text_input_embeds = text_input_embeds_list[i]
|
text_input_embeds = text_input_embeds_list[i]
|
||||||
if self.codec_lm_padding_side == "right":
|
if self.codec_lm_padding_side == "right":
|
||||||
audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
|
audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
|
||||||
elif self.codec_lm_padding_side == "left":
|
elif self.codec_lm_padding_side == "left":
|
||||||
start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
|
start_idx = torch.where(
|
||||||
|
audio_codes[i] == self.codec_lm.config.mask_token_id
|
||||||
|
)[0][0]
|
||||||
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
|
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
|
||||||
assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
|
assert (
|
||||||
|
start_idx == start_idx_re_compute
|
||||||
|
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
|
||||||
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
|
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
|
||||||
text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
|
text_input_embeds = text_input_embeds[
|
||||||
logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
|
: audio_embeddings.shape[1] - start_idx
|
||||||
audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds
|
]
|
||||||
|
logging.warning(
|
||||||
|
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
|
||||||
|
)
|
||||||
|
audio_embeddings[
|
||||||
|
i, start_idx : start_idx + text_input_embeds.shape[0]
|
||||||
|
] += text_input_embeds
|
||||||
|
|
||||||
speech_outputs = self.codec_lm(
|
speech_outputs = self.codec_lm(
|
||||||
attention_mask=audio_attention_mask,
|
attention_mask=audio_attention_mask,
|
||||||
@ -369,8 +433,10 @@ class SPEECH_LLM(nn.Module):
|
|||||||
)
|
)
|
||||||
last_hidden_state = speech_outputs.hidden_states[-1].clone()
|
last_hidden_state = speech_outputs.hidden_states[-1].clone()
|
||||||
|
|
||||||
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
|
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
|
||||||
audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size)
|
audio_logits = audio_logits.contiguous().view(
|
||||||
|
-1, self.codec_lm.config.vocab_size
|
||||||
|
)
|
||||||
audio_labels = audio_labels.contiguous().view(-1)
|
audio_labels = audio_labels.contiguous().view(-1)
|
||||||
audio_labels = audio_labels.masked_fill(
|
audio_labels = audio_labels.masked_fill(
|
||||||
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
|
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
|
||||||
@ -378,7 +444,6 @@ class SPEECH_LLM(nn.Module):
|
|||||||
codec_loss = self.loss_fct(audio_logits, audio_labels)
|
codec_loss = self.loss_fct(audio_logits, audio_labels)
|
||||||
audio_preds = torch.argmax(audio_logits, -1)
|
audio_preds = torch.argmax(audio_logits, -1)
|
||||||
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
preds = torch.argmax(model_outputs.logits, -1)
|
preds = torch.argmax(model_outputs.logits, -1)
|
||||||
acc = compute_accuracy(
|
acc = compute_accuracy(
|
||||||
@ -392,12 +457,11 @@ class SPEECH_LLM(nn.Module):
|
|||||||
ignore_label=IGNORE_TOKEN_ID,
|
ignore_label=IGNORE_TOKEN_ID,
|
||||||
)
|
)
|
||||||
audio_topk_acc = self.audio_accuracy_metric(
|
audio_topk_acc = self.audio_accuracy_metric(
|
||||||
audio_logits.detach(),
|
audio_logits.detach(), audio_labels.detach()
|
||||||
audio_labels.detach()).item()
|
).item()
|
||||||
|
|
||||||
|
|
||||||
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
|
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
fbank: torch.Tensor = None,
|
fbank: torch.Tensor = None,
|
||||||
@ -453,12 +517,12 @@ class SPEECH_LLM(nn.Module):
|
|||||||
def decode_with_speech_output(
|
def decode_with_speech_output(
|
||||||
self,
|
self,
|
||||||
fbank: torch.Tensor = None,
|
fbank: torch.Tensor = None,
|
||||||
input_ids: torch.LongTensor = None, # Prompt input_ids
|
input_ids: torch.LongTensor = None, # Prompt input_ids
|
||||||
attention_mask: torch.Tensor = None, # Prompt attention_mask
|
attention_mask: torch.Tensor = None, # Prompt attention_mask
|
||||||
max_text_new_tokens: int = 1024,
|
max_text_new_tokens: int = 1024,
|
||||||
max_speech_new_tokens: int = 1024, # Max length for speech tokens
|
max_speech_new_tokens: int = 1024, # Max length for speech tokens
|
||||||
llm_kwargs: dict = None, # Kwargs for text LLM generate
|
llm_kwargs: dict = None, # Kwargs for text LLM generate
|
||||||
codec_lm_kwargs: dict = None # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
|
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
|
||||||
) -> Tuple[torch.LongTensor, List[List[int]]]:
|
) -> Tuple[torch.LongTensor, List[List[int]]]:
|
||||||
"""
|
"""
|
||||||
Generates text and corresponding speech tokens using the revised logic.
|
Generates text and corresponding speech tokens using the revised logic.
|
||||||
@ -479,16 +543,22 @@ class SPEECH_LLM(nn.Module):
|
|||||||
the generated speech codec tokens for a batch item.
|
the generated speech codec tokens for a batch item.
|
||||||
"""
|
"""
|
||||||
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
|
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
|
||||||
if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
|
if (
|
||||||
raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
|
not self.codec_lm
|
||||||
|
or not self.speech_token_projector
|
||||||
|
or not self.codec_lm_head
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"codec_lm and associated layers must be initialized to generate speech output."
|
||||||
|
)
|
||||||
|
|
||||||
device = next(self.parameters()).device # Use model's device
|
device = next(self.parameters()).device # Use model's device
|
||||||
batch_size = fbank.shape[0]
|
batch_size = fbank.shape[0]
|
||||||
|
|
||||||
# --- 1. Prepare Prompt Embeddings ---
|
# --- 1. Prepare Prompt Embeddings ---
|
||||||
encoder_outs = self.encoder(fbank)
|
encoder_outs = self.encoder(fbank)
|
||||||
speech_features = self.encoder_projector(encoder_outs)
|
speech_features = self.encoder_projector(encoder_outs)
|
||||||
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
|
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
|
||||||
|
|
||||||
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
|
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
@ -511,12 +581,12 @@ class SPEECH_LLM(nn.Module):
|
|||||||
"eos_token_id": self.llm.config.eos_token_id,
|
"eos_token_id": self.llm.config.eos_token_id,
|
||||||
"pad_token_id": self.llm.config.pad_token_id,
|
"pad_token_id": self.llm.config.pad_token_id,
|
||||||
"num_beams": 1,
|
"num_beams": 1,
|
||||||
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
|
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
|
||||||
"top_p": 0.5,
|
"top_p": 0.5,
|
||||||
"top_k": 20,
|
"top_k": 20,
|
||||||
"repetition_penalty": 1.1,
|
"repetition_penalty": 1.1,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
**(llm_kwargs or {}) # User-provided kwargs override defaults
|
**(llm_kwargs or {}), # User-provided kwargs override defaults
|
||||||
}
|
}
|
||||||
|
|
||||||
text_outputs = self.llm.generate(
|
text_outputs = self.llm.generate(
|
||||||
@ -525,17 +595,22 @@ class SPEECH_LLM(nn.Module):
|
|||||||
max_new_tokens=max_text_new_tokens,
|
max_new_tokens=max_text_new_tokens,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
**final_llm_kwargs
|
**final_llm_kwargs,
|
||||||
)
|
)
|
||||||
delay_step = 1
|
delay_step = 1
|
||||||
generated_text_ids = text_outputs.sequences # [B, S_full]
|
generated_text_ids = text_outputs.sequences # [B, S_full]
|
||||||
eos_token_id = self.llm.config.eos_token_id
|
eos_token_id = self.llm.config.eos_token_id
|
||||||
eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D
|
eos_token_embedding = self.llm.get_input_embeddings()(
|
||||||
assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
|
torch.tensor([[eos_token_id]], device=device)
|
||||||
|
) # 1,D
|
||||||
|
assert (
|
||||||
|
generated_text_ids[0, -1] == eos_token_id
|
||||||
|
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
|
||||||
thinker_token_embeds_org = [
|
thinker_token_embeds_org = [
|
||||||
token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
|
token_hidden_states[0].to(self.llm.device)
|
||||||
|
for token_hidden_states in text_outputs.hidden_states
|
||||||
]
|
]
|
||||||
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
|
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
|
||||||
first_thinker_token_embed = torch.cat(
|
first_thinker_token_embed = torch.cat(
|
||||||
[
|
[
|
||||||
thinker_token_embeds_org[0][:, 1:],
|
thinker_token_embeds_org[0][:, 1:],
|
||||||
@ -544,19 +619,27 @@ class SPEECH_LLM(nn.Module):
|
|||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
|
thinker_token_embeds = (
|
||||||
|
[first_thinker_token_embed]
|
||||||
|
+ thinker_token_embeds_org[2:]
|
||||||
|
+ [eos_token_embedding]
|
||||||
|
)
|
||||||
thinker_hidden_states = [
|
thinker_hidden_states = [
|
||||||
token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
|
token_hidden_states[-1].to(self.llm.device)
|
||||||
|
for token_hidden_states in text_outputs.hidden_states
|
||||||
]
|
]
|
||||||
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
|
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
|
||||||
thinker_reply_part = [torch.cat(
|
thinker_reply_part = [
|
||||||
[
|
torch.cat(
|
||||||
thinker_hidden_state,
|
[
|
||||||
thinker_token_embed,
|
thinker_hidden_state,
|
||||||
],
|
thinker_token_embed,
|
||||||
dim=-1,
|
],
|
||||||
)
|
dim=-1,
|
||||||
for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
|
)
|
||||||
|
for thinker_hidden_state, thinker_token_embed in zip(
|
||||||
|
thinker_hidden_states[1:], thinker_token_embeds[1:]
|
||||||
|
)
|
||||||
]
|
]
|
||||||
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
|
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
|
||||||
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
|
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
|
||||||
@ -568,26 +651,35 @@ class SPEECH_LLM(nn.Module):
|
|||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec]
|
thinker_prompt_part = self.speech_token_projector(
|
||||||
thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec]
|
thinker_prompt_part
|
||||||
|
) # [B, S_full, D_codec]
|
||||||
|
thinker_reply_part = self.speech_token_projector(
|
||||||
|
thinker_reply_part
|
||||||
|
) # [B, S_full, D_codec]
|
||||||
|
|
||||||
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
|
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
|
||||||
talker_input_ids = torch.full(
|
talker_input_ids = torch.full(
|
||||||
(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
|
(batch_size, thinker_prompt_part_seq_len + delay_step + 1),
|
||||||
|
self.codec_lm.config.mask_token_id,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.llm.device,
|
||||||
)
|
)
|
||||||
talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id
|
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
|
||||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
|
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
|
||||||
|
talker_input_ids
|
||||||
|
) # [B, S_full, D_codec]
|
||||||
thinker_input_embeds = torch.cat(
|
thinker_input_embeds = torch.cat(
|
||||||
[
|
[
|
||||||
thinker_prompt_part,
|
thinker_prompt_part,
|
||||||
thinker_reply_part[:, :delay_step + 1, :],
|
thinker_reply_part[:, : delay_step + 1, :],
|
||||||
],
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
talker_inputs_embeds += thinker_input_embeds
|
talker_inputs_embeds += thinker_input_embeds
|
||||||
thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec]
|
thinker_reply_part = thinker_reply_part[
|
||||||
|
:, delay_step + 1 :, :
|
||||||
|
] # [B, S_full, D_codec]
|
||||||
|
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
# generated_speech_tokens_list = [[] for _ in range(batch_size)]
|
# generated_speech_tokens_list = [[] for _ in range(batch_size)]
|
||||||
@ -599,10 +691,14 @@ class SPEECH_LLM(nn.Module):
|
|||||||
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
|
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
|
||||||
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
|
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
|
||||||
if t > 0:
|
if t > 0:
|
||||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
|
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
|
||||||
|
next_token_ids
|
||||||
|
) # [B, 1, D_codec]
|
||||||
if thinker_reply_part.shape[1] > 0:
|
if thinker_reply_part.shape[1] > 0:
|
||||||
talker_inputs_embeds += thinker_reply_part[:, :1, :]
|
talker_inputs_embeds += thinker_reply_part[:, :1, :]
|
||||||
thinker_reply_part = thinker_reply_part[:, 1:, :] # Remove the first token for next step
|
thinker_reply_part = thinker_reply_part[
|
||||||
|
:, 1:, :
|
||||||
|
] # Remove the first token for next step
|
||||||
# # Add the projected text embedding corresponding to the current timestep `t`
|
# # Add the projected text embedding corresponding to the current timestep `t`
|
||||||
# if t < text_context_len:
|
# if t < text_context_len:
|
||||||
# # Text context from the full generated text sequence
|
# # Text context from the full generated text sequence
|
||||||
@ -611,20 +707,24 @@ class SPEECH_LLM(nn.Module):
|
|||||||
# else:
|
# else:
|
||||||
# # No more text context to add
|
# # No more text context to add
|
||||||
# inputs_embeds = current_speech_embeds
|
# inputs_embeds = current_speech_embeds
|
||||||
|
|
||||||
# Forward pass through codec LM for one step
|
# Forward pass through codec LM for one step
|
||||||
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
|
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
|
||||||
codec_outputs = self.codec_lm(
|
codec_outputs = self.codec_lm(
|
||||||
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
|
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
# No attention mask needed here when using past_key_values and single token input
|
# No attention mask needed here when using past_key_values and single token input
|
||||||
)
|
)
|
||||||
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] #TODO: check shape here
|
last_token_hidden_state = codec_outputs.hidden_states[-1][
|
||||||
|
:, -1, :
|
||||||
|
] # [B, D_codec] #TODO: check shape here
|
||||||
# Get logits for the *last* token generated in this step
|
# Get logits for the *last* token generated in this step
|
||||||
next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
|
next_token_logits = self.codec_lm_head(
|
||||||
|
last_token_hidden_state
|
||||||
|
) # Use -1 index
|
||||||
# suppress tokens between 4096:len(vocab)-3
|
# suppress tokens between 4096:len(vocab)-3
|
||||||
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
|
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
|
||||||
next_token_ids = topk_sampling(
|
next_token_ids = topk_sampling(
|
||||||
@ -634,11 +734,14 @@ class SPEECH_LLM(nn.Module):
|
|||||||
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
|
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
|
||||||
break
|
break
|
||||||
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
|
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
|
||||||
past_key_values = codec_outputs.past_key_values # Update KV cache
|
past_key_values = codec_outputs.past_key_values # Update KV cache
|
||||||
generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])
|
generated_speech_tokens_list.append(
|
||||||
|
next_token_ids.squeeze(1).cpu().tolist()[0]
|
||||||
|
)
|
||||||
# --- 6. Return Results ---
|
# --- 6. Return Results ---
|
||||||
return generated_text_ids, generated_speech_tokens_list
|
return generated_text_ids, generated_speech_tokens_list
|
||||||
|
|
||||||
|
|
||||||
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
|
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||||
"""Calculate accuracy.
|
"""Calculate accuracy.
|
||||||
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
|
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
|
||||||
@ -717,4 +820,4 @@ def top_k_top_p_filtering(
|
|||||||
1, sorted_indices, sorted_indices_to_remove
|
1, sorted_indices, sorted_indices_to_remove
|
||||||
)
|
)
|
||||||
logits[indices_to_remove] = filter_value
|
logits[indices_to_remove] = filter_value
|
||||||
return logits
|
return logits
|
@@ -1,13 +1,12 @@
 from typing import Callable, Dict, List, Union

 import torch
-from torch.utils.data.dataloader import DataLoader, default_collate

 from lhotse import validate
 from lhotse.cut import CutSet
 from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
 from lhotse.utils import compute_num_frames, ifnone
 from lhotse.workarounds import Hdf5MemoryIssueFix
+from torch.utils.data.dataloader import DataLoader, default_collate


 class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
@ -1,26 +1,25 @@
|
|||||||
# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
|
# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
|
||||||
import io
|
import io
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import gradio as gr
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
import gradio.processing_utils as processing_utils
|
|
||||||
import tempfile
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
|
|
||||||
from gradio_client import utils as client_utils
|
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
import whisper
|
|
||||||
import torch
|
|
||||||
from peft import LoraConfig, get_peft_model
|
|
||||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
|
||||||
from model import SPEECH_LLM, EncoderProjector
|
|
||||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
|
||||||
import sherpa_onnx
|
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS')
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import gradio.processing_utils as processing_utils
|
||||||
|
import numpy as np
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import whisper
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||||
|
from gradio_client import utils as client_utils
|
||||||
|
from model import SPEECH_LLM, EncoderProjector
|
||||||
|
from peft import LoraConfig, get_peft_model
|
||||||
|
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
|
||||||
|
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||||
|
|
||||||
|
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
|
||||||
|
|
||||||
def get_model(params, device="cuda"):
|
def get_model(params, device="cuda"):
|
||||||
@ -88,7 +87,7 @@ def get_model(params, device="cuda"):
|
|||||||
codec_lm = AutoModelForCausalLM.from_config(
|
codec_lm = AutoModelForCausalLM.from_config(
|
||||||
config=config,
|
config=config,
|
||||||
attn_implementation=attn_implementation,
|
attn_implementation=attn_implementation,
|
||||||
torch_dtype=torch.float16
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
codec_lm.resize_token_embeddings(codec_vocab_size)
|
codec_lm.resize_token_embeddings(codec_vocab_size)
|
||||||
codec_lm.vocab_size = codec_vocab_size
|
codec_lm.vocab_size = codec_vocab_size
|
||||||
@ -102,12 +101,10 @@ def get_model(params, device="cuda"):
|
|||||||
llm,
|
llm,
|
||||||
encoder_projector,
|
encoder_projector,
|
||||||
codec_lm,
|
codec_lm,
|
||||||
codec_lm_padding_side= "left" if params.use_flash_attn else "right",
|
codec_lm_padding_side="left" if params.use_flash_attn else "right",
|
||||||
)
|
)
|
||||||
|
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
|
||||||
f"{params.checkpoint_path}", map_location="cpu"
|
|
||||||
)
|
|
||||||
model.load_state_dict(checkpoint, strict=False)
|
model.load_state_dict(checkpoint, strict=False)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
@ -122,27 +119,37 @@ def audio_decode_cosyvoice(audio_tokens, codec_decoder):
|
|||||||
Args:
|
Args:
|
||||||
audio_tokens (list): List of audio tokens to be processed.
|
audio_tokens (list): List of audio tokens to be processed.
|
||||||
codec_decoder: Codec decoder for generating audio.
|
codec_decoder: Codec decoder for generating audio.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Generated audio waveform.
|
torch.Tensor: Generated audio waveform.
|
||||||
"""
|
"""
|
||||||
flow_embedding = codec_decoder.frontend.spk2info['中文女']['embedding']
|
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
|
||||||
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
|
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
|
||||||
prompt_speech_feat = torch.zeros(1, 0, 80)
|
prompt_speech_feat = torch.zeros(1, 0, 80)
|
||||||
tts_mel, _ = codec_decoder.model.flow.inference(token=audio_tokens.to(codec_decoder.model.device),
|
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
|
token=audio_tokens.to(codec_decoder.model.device),
|
||||||
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
|
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||||
prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
|
codec_decoder.model.device
|
||||||
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
|
),
|
||||||
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
|
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
|
||||||
embedding=flow_embedding.to(codec_decoder.model.device),
|
prompt_token_len=torch.tensor(
|
||||||
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),)
|
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
|
||||||
|
).to(codec_decoder.model.device),
|
||||||
|
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
|
||||||
|
prompt_feat_len=torch.tensor(
|
||||||
|
[prompt_speech_feat.shape[1]], dtype=torch.int32
|
||||||
|
).to(codec_decoder.model.device),
|
||||||
|
embedding=flow_embedding.to(codec_decoder.model.device),
|
||||||
|
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||||
audio_hat, _ = codec_decoder.model.hift.inference(speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0))
|
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
return audio_hat
|
return audio_hat
|
||||||
|
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
messages,
|
messages,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@ -178,28 +185,14 @@ def preprocess(
|
|||||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
return input_ids, attention_mask
|
return input_ids, attention_mask
|
||||||
|
|
||||||
|
|
||||||
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
|
||||||
|
|
||||||
|
|
||||||
|
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||||
def format_history(history: list):
|
def format_history(history: list):
|
||||||
messages = []
|
messages = []
|
||||||
for item in history:
|
for item in history:
|
||||||
if isinstance(item["content"], str):
|
if isinstance(item["content"], str):
|
||||||
messages.append({"role": item['role'], "content": item['content']})
|
messages.append({"role": item["role"], "content": item["content"]})
|
||||||
# elif item["role"] == "user" and (isinstance(item["content"], list) or
|
|
||||||
# isinstance(item["content"], tuple)):
|
|
||||||
# file_path = item["content"][0]
|
|
||||||
# # TODO: check if the file_path's transcript is already in the history
|
|
||||||
# mime_type = client_utils.get_mimetype(file_path)
|
|
||||||
# if mime_type.startswith("audio"):
|
|
||||||
# messages.append({
|
|
||||||
# "role":
|
|
||||||
# item['role'],
|
|
||||||
# "content": item["content"][1] # append audio transcript here
|
|
||||||
# })
|
|
||||||
print('predict history: ', messages)
|
|
||||||
# messages = messages[-2:] # TODO: WAR: add history later
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
@ -217,9 +210,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
|||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
device = model.llm.device
|
device = model.llm.device
|
||||||
|
|
||||||
feature = feature.to(device, dtype=dtype)#.transpose(1, 2)
|
feature = feature.to(device, dtype=dtype)
|
||||||
# assert feature.shape[2] == 80
|
|
||||||
|
|
||||||
input_ids, attention_mask = preprocess([messages], tokenizer)
|
input_ids, attention_mask = preprocess([messages], tokenizer)
|
||||||
|
|
||||||
generated_ids, audio_tokens = model.decode_with_speech_output(
|
generated_ids, audio_tokens = model.decode_with_speech_output(
|
||||||
@ -227,26 +219,21 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
# print('hyps: ', hyps, 23333333333333333333333333)
|
|
||||||
yield {"type": "text", "data": hyps[0]}
|
yield {"type": "text", "data": hyps[0]}
|
||||||
# yield {"type": "text", "data": hyps}
|
|
||||||
|
|
||||||
audio_tokens = [token for token in audio_tokens if token < 4096]
|
audio_tokens = [token for token in audio_tokens if token < 4096]
|
||||||
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
|
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
|
||||||
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
|
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
|
||||||
audio = audio_hat.squeeze(0).cpu().numpy()
|
audio = audio_hat.squeeze(0).cpu().numpy()
|
||||||
# sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050)
|
|
||||||
audio = np.array(audio * 32767).astype(np.int16)
|
audio = np.array(audio * 32767).astype(np.int16)
|
||||||
# yield {"type": "audio", "data": (22050, audio)}
|
|
||||||
# with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
|
|
||||||
# sf.write(tmpfile.name, audio, 22050, format="WAV")
|
|
||||||
# audio_path = tmpfile.name
|
|
||||||
wav_io = io.BytesIO()
|
wav_io = io.BytesIO()
|
||||||
sf.write(wav_io, audio, samplerate=22050, format="WAV")
|
sf.write(wav_io, audio, samplerate=22050, format="WAV")
|
||||||
wav_io.seek(0)
|
wav_io.seek(0)
|
||||||
wav_bytes = wav_io.getvalue()
|
wav_bytes = wav_io.getvalue()
|
||||||
audio_path = processing_utils.save_bytes_to_cache(
|
audio_path = processing_utils.save_bytes_to_cache(
|
||||||
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
|
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
|
||||||
|
)
|
||||||
|
|
||||||
yield {"type": "audio", "data": audio_path}
|
yield {"type": "audio", "data": audio_path}
|
||||||
|
|
||||||
@ -259,25 +246,27 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
|||||||
gr.update(visible=True), # stop_btn
|
gr.update(visible=True), # stop_btn
|
||||||
)
|
)
|
||||||
print(2333, history, audio)
|
print(2333, history, audio)
|
||||||
history.append({"role": "user", "content": (audio,)})
|
history.append({"role": "user", "content": (audio,)})
|
||||||
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
|
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
|
||||||
history.append({"role": "assistant", "content": ""})
|
history.append({"role": "assistant", "content": ""})
|
||||||
formatted_history = format_history(history=history) # only keep string text format
|
formatted_history = format_history(
|
||||||
|
history=history
|
||||||
|
) # only keep string text format
|
||||||
|
|
||||||
assert audio is not None
|
assert audio is not None
|
||||||
audio_transcript = get_transcript(
|
audio_transcript = get_transcript(
|
||||||
audio,
|
audio,
|
||||||
asr_model,
|
asr_model,
|
||||||
)
|
)
|
||||||
print('audio_transcript: ', audio_transcript)
|
|
||||||
history[-2]["content"] = audio_transcript
|
history[-2]["content"] = audio_transcript
|
||||||
|
|
||||||
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
|
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
|
||||||
fbank = fbank.unsqueeze(0)
|
fbank = fbank.unsqueeze(0)
|
||||||
assert fbank.ndim == 3
|
assert fbank.ndim == 3
|
||||||
|
|
||||||
# history.append({"role": "assistant", "content": ""})
|
for chunk in decode(
|
||||||
for chunk in decode(model, token2wav_model, tokenizer, fbank, formatted_history):
|
model, token2wav_model, tokenizer, fbank, formatted_history
|
||||||
|
):
|
||||||
if chunk["type"] == "text":
|
if chunk["type"] == "text":
|
||||||
history[-1]["content"] = chunk["data"]
|
history[-1]["content"] = chunk["data"]
|
||||||
yield (
|
yield (
|
||||||
@ -287,10 +276,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
|||||||
gr.update(visible=True), # stop_btn
|
gr.update(visible=True), # stop_btn
|
||||||
)
|
)
|
||||||
if chunk["type"] == "audio":
|
if chunk["type"] == "audio":
|
||||||
history.append({
|
history.append(
|
||||||
"role": "assistant",
|
{"role": "assistant", "content": gr.Audio(chunk["data"])}
|
||||||
"content": gr.Audio(chunk["data"])
|
)
|
||||||
})
|
|
||||||
|
|
||||||
# Final yield
|
# Final yield
|
||||||
yield (
|
yield (
|
||||||
@ -304,8 +292,7 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
|||||||
with gr.Tab("Online"):
|
with gr.Tab("Online"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
microphone = gr.Audio(sources=['microphone'],
|
microphone = gr.Audio(sources=["microphone"], type="filepath")
|
||||||
type="filepath")
|
|
||||||
submit_btn = gr.Button("Submit", variant="primary")
|
submit_btn = gr.Button("Submit", variant="primary")
|
||||||
stop_btn = gr.Button("Stop", visible=False)
|
stop_btn = gr.Button("Stop", visible=False)
|
||||||
clear_btn = gr.Button("Clear History")
|
clear_btn = gr.Button("Clear History")
|
||||||
@ -315,64 +302,80 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
|||||||
def clear_history():
|
def clear_history():
|
||||||
return [], gr.update(value=None)
|
return [], gr.update(value=None)
|
||||||
|
|
||||||
submit_event = submit_btn.click(fn=media_predict,
|
submit_event = submit_btn.click(
|
||||||
inputs=[
|
fn=media_predict,
|
||||||
microphone,
|
inputs=[
|
||||||
media_chatbot,
|
microphone,
|
||||||
],
|
media_chatbot,
|
||||||
outputs=[
|
],
|
||||||
microphone,
|
outputs=[microphone, media_chatbot, submit_btn, stop_btn],
|
||||||
media_chatbot, submit_btn,
|
)
|
||||||
stop_btn
|
|
||||||
])
|
|
||||||
stop_btn.click(
|
stop_btn.click(
|
||||||
fn=lambda:
|
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
|
||||||
(gr.update(visible=True), gr.update(visible=False)),
|
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=[submit_btn, stop_btn],
|
outputs=[submit_btn, stop_btn],
|
||||||
cancels=[submit_event],
|
cancels=[submit_event],
|
||||||
queue=False)
|
queue=False,
|
||||||
clear_btn.click(fn=clear_history,
|
)
|
||||||
inputs=None,
|
clear_btn.click(
|
||||||
outputs=[media_chatbot, microphone])
|
fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
|
||||||
|
)
|
||||||
|
|
||||||
demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
|
demo.queue(default_concurrency_limit=100, max_size=100).launch(
|
||||||
ssr_mode=False,
|
max_threads=100,
|
||||||
share=args.share,
|
ssr_mode=False,
|
||||||
inbrowser=args.inbrowser,
|
share=args.share,
|
||||||
server_port=args.server_port,
|
inbrowser=args.inbrowser,
|
||||||
server_name=args.server_name,)
|
server_port=args.server_port,
|
||||||
|
server_name=args.server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_args():
|
def _get_args():
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument('--checkpoint-path',
|
parser.add_argument(
|
||||||
type=str,
|
"--checkpoint-path",
|
||||||
default=None,
|
type=str,
|
||||||
help='Checkpoint name or path, default to %(default)r')
|
default=None,
|
||||||
parser.add_argument('--token2wav-path',
|
help="Checkpoint name or path, default to %(default)r",
|
||||||
type=str,
|
)
|
||||||
default=None,
|
parser.add_argument(
|
||||||
help='Token2Wav path, default to %(default)r')
|
"--token2wav-path",
|
||||||
parser.add_argument('--asr-model-dir',
|
type=str,
|
||||||
type=str,
|
default=None,
|
||||||
default=None,
|
help="Token2Wav path, default to %(default)r",
|
||||||
help='ASR model dir, default to %(default)r')
|
)
|
||||||
parser.add_argument('--flash-attn2',
|
parser.add_argument(
|
||||||
action='store_true',
|
"--asr-model-dir",
|
||||||
default=False,
|
type=str,
|
||||||
help='Enable flash_attention_2 when loading the model.')
|
default=None,
|
||||||
parser.add_argument('--share',
|
help="ASR model dir, default to %(default)r",
|
||||||
action='store_true',
|
)
|
||||||
default=False,
|
parser.add_argument(
|
||||||
help='Create a publicly shareable link for the interface.')
|
"--flash-attn2",
|
||||||
parser.add_argument('--inbrowser',
|
action="store_true",
|
||||||
action='store_true',
|
default=False,
|
||||||
default=False,
|
help="Enable flash_attention_2 when loading the model.",
|
||||||
help='Automatically launch the interface in a new tab on the default browser.')
|
)
|
||||||
parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.')
|
parser.add_argument(
|
||||||
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
|
"--share",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Create a publicly shareable link for the interface.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--inbrowser",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Automatically launch the interface in a new tab on the default browser.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-port", type=int, default=8001, help="Demo server port."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
|
||||||
|
)
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
@ -401,6 +404,7 @@ def read_wave(wave_filename: str):
|
|||||||
|
|
||||||
return samples_float32, sample_rate
|
return samples_float32, sample_rate
|
||||||
|
|
||||||
|
|
||||||
def get_transcript(audio_path, recognizer):
|
def get_transcript(audio_path, recognizer):
|
||||||
samples, sample_rate = read_wave(audio_path)
|
samples, sample_rate = read_wave(audio_path)
|
||||||
s = recognizer.create_stream()
|
s = recognizer.create_stream()
|
||||||
@ -408,10 +412,13 @@ def get_transcript(audio_path, recognizer):
|
|||||||
recognizer.decode_streams([s])
|
recognizer.decode_streams([s])
|
||||||
return s.result.text
|
return s.result.text
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = _get_args()
|
args = _get_args()
|
||||||
model, tokenizer = get_model(args)
|
model, tokenizer = get_model(args)
|
||||||
token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
|
token2wav = CosyVoice(
|
||||||
|
args.token2wav_path, load_jit=False, load_trt=False, fp16=False
|
||||||
|
)
|
||||||
|
|
||||||
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||||
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
|
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
|
||||||
@ -423,4 +430,4 @@ if __name__ == "__main__":
|
|||||||
debug=False,
|
debug=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
_launch_demo(args, model, tokenizer, token2wav, asr_model)
|
_launch_demo(args, model, tokenizer, token2wav, asr_model)
|