update README

root 2025-04-29 08:49:12 +00:00
parent 448a4eeea7
commit 360f0aa397
12 changed files with 423 additions and 217 deletions

View File

@ -0,0 +1,89 @@
# Introduction

This recipe includes scripts for training speech2speech models.

# SPEECH2SPEECH

The following table lists the folders for different tasks.

| Recipe | Speech Input | Speech Output | Comment |
|--------|--------------|---------------|---------|
| Qwen-Omni-like | Continuous embeddings | CosyVoice 1 50 Hz single-codebook tokens | Text-driven: the Thinker LLM generates the text tokens and a small Talker LLM generates the speech tokens |
### [Qwen-Omni-like Speech2Speech Recipe](./qwen_omni)

A [Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni)-style model trained on the [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.
<br>
<p align="center">
  <img src="assets/framework.jpg" width="800"/>
</p>
<br>
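The model is text-driven: the Talker's input sequence is aligned with the Thinker's, the question positions are masked out, and speech tokens start one delay step after the answer text begins. Below is a toy sketch of that token layout; the special-token ids and lengths are made up for illustration, and the real logic lives in the `SPEECH_LLM` model added by this commit.

```python
# Toy illustration of the Talker (codec LM) input/label layout used during training.
# MASK/BOS/EOS/PAD ids and the example lengths are invented for this sketch.
MASK, BOS, EOS, PAD = 4096, 4097, 4098, 4099
DELAY_STEP = 1  # speech tokens trail the answer text by one step


def build_talker_io(question_len, codec_ids, max_len):
    prefix = question_len + DELAY_STEP
    codes = [MASK] * prefix + [BOS] + codec_ids    # Talker input ids
    labels = [PAD] * prefix + codec_ids + [EOS]    # Talker targets
    pad = max_len - len(codes)
    return codes + [PAD] * pad, labels + [PAD] * pad


codes, labels = build_talker_io(question_len=3, codec_ids=[11, 22, 33], max_len=10)
print(codes)   # [4096, 4096, 4096, 4096, 4097, 11, 22, 33, 4099, 4099]
print(labels)  # [4099, 4099, 4099, 4099, 11, 22, 33, 4098, 4099, 4099]
```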
The training command is:
```bash
pip install -r whisper_llm_zh/requirements.txt
pip install "huggingface_hub[cli]"
mkdir -p models/whisper models/qwen
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
# First, we only train the projector and freeze other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora False --unfreeze-llm False
# Then we jointly train the projector and LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True \
--pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
```
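The two stages above differ only in which parameters are trainable. As a rough sketch of what `--unfreeze-llm` / `--use-lora` boil down to (the tiny modules below are stand-ins, not the recipe's real classes, and the LoRA hyper-parameters are illustrative):

```python
# Sketch of the two-stage trainability setup, assuming stage 1 freezes the LLM and
# trains only the projector, while stage 2 adds LoRA adapters on the LLM.
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyLLM(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.q_proj(x) + self.v_proj(x)


llm = TinyLLM()
projector = nn.Linear(8, 16)  # stands in for the EncoderProjector

# Stage 1: --unfreeze-llm False --use-lora False
for p in llm.parameters():
    p.requires_grad = False  # LLM frozen; only the projector learns

# Stage 2: --unfreeze-llm True --use-lora True
lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
llm = get_peft_model(llm, lora_cfg)  # base weights stay frozen; LoRA adapters (and the projector) train
llm.print_trainable_parameters()
```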
The decoding command is:
```bash
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct  # matches the Qwen2-1.5B checkpoint above
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s $(pwd)/models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
```
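The decode command above evaluates the text output (ASR-style decoding on aishell). For speech output, the Talker's generated codec tokens are turned back into a waveform with CosyVoice. A sketch of that step, following the `audio_decode_cosyvoice` helper in the web demo added by this commit; the CosyVoice checkpoint path and the example token ids are placeholders:

```python
# Sketch: CosyVoice codec tokens -> waveform, mirroring the web demo in this recipe.
import soundfile as sf
import torch
from cosyvoice.cli.cosyvoice import CosyVoice

token2wav = CosyVoice("models/CosyVoice-300M-SFT", load_jit=False, load_trt=False, fp16=False)

speech_tokens = [1024, 2048, 512]                       # e.g. from decode_with_speech_output()
speech_tokens = [t for t in speech_tokens if t < 4096]  # drop special tokens
speech_tokens = torch.tensor(speech_tokens, dtype=torch.int32).unsqueeze(0)

# audio_decode_cosyvoice() runs the CosyVoice flow-matching decoder and HiFT vocoder;
# it is defined in the demo script shown later in this commit.
audio = audio_decode_cosyvoice(speech_tokens, token2wav)
sf.write("answer.wav", audio.squeeze(0).cpu().numpy(), 22050)
```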

Binary file not shown (new image, 331 KiB).

View File

```diff
@@ -17,6 +17,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Usage:
+python3 local/compute_whisper_fbank.py \
+    --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+    --out-dir data/fbank \
+    --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
+    --audio-key question_audio --text-key answer \
+    --prefix ultrachat
+"""
+
 import argparse
 import logging
@@ -126,7 +136,7 @@ def compute_fbank(args):
     num_digits = 5
     for i in range(num_shards):
         shard = dataset.shard(num_shards, i)
-        shard = shard.take(10)  # for testing
+        # shard = shard.take(10)  # for testing
         logging.info(
             f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
         )
@@ -159,8 +169,6 @@ def compute_fbank(args):
         logging.info(f"Saving to {cuts_path}")
         # see https://github.com/lhotse-speech/lhotse/issues/1125
         cut_set.drop_recordings().to_file(cuts_path)
-        if i > 1:
-            break


 def main():
```
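The fbank computation above writes per-shard cut manifests with the recordings dropped (see the linked lhotse issue). A quick way to spot-check the output; the manifest filename below is a guess based on the `--prefix` argument and shard index:

```python
# Spot-check a cuts manifest written by compute_whisper_fbank.py (filename is a guess).
from lhotse import CutSet

cuts = CutSet.from_file("data/fbank/ultrachat_cuts_00000.jsonl.gz")
cut = next(iter(cuts))
print(cut.id, cut.duration, cut.supervisions[0].text)
print(cut.load_features().shape)  # (num_frames, 80) whisper fbank
```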

View File

```diff
@@ -1,11 +1,14 @@
+from typing import List, Tuple  # Added for type hints
+
 import torch
 from torch import nn
-from transformers.trainer_pt_utils import LabelSmoother
-from typing import List, Tuple  # Added for type hints
 from torchmetrics.classification import MulticlassAccuracy
+from transformers.trainer_pt_utils import LabelSmoother
+
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 import logging
+

 class EncoderProjector(nn.Module):
     """
     The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
@@ -69,7 +72,8 @@ class SPEECH_LLM(nn.Module):
         self.codec_lm = codec_lm
         if self.codec_lm:
             self.speech_token_projector = nn.Linear(
-                self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
+                self.llm.config.hidden_size + self.llm.config.hidden_size,
+                self.codec_lm.config.hidden_size,
             )
             self.codec_lm_head = nn.Linear(
                 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
@@ -89,6 +93,7 @@ class SPEECH_LLM(nn.Module):
                 multidim_average="global",
                 ignore_index=IGNORE_TOKEN_ID,
             )
+
     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
     ):
```
@ -274,68 +279,115 @@ class SPEECH_LLM(nn.Module):
) = self._merge_input_ids_with_speech_features( ) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels speech_features, inputs_embeds, input_ids, attention_mask, labels
) )
input_seq_len = attention_mask.sum(dim=1) # shape, B input_seq_len = attention_mask.sum(dim=1) # shape, B
text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], [] (
text_label_start_index_list,
text_input_start_index_list,
input_question_len_list,
) = ([], [], [])
for i in range(labels.shape[0]): for i in range(labels.shape[0]):
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0] input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
input_embeds_start_index = input_embeds_valid_index[0] input_embeds_start_index = input_embeds_valid_index[0]
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0] text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
text_labels_start_index = text_labels_valid_index[0] text_labels_start_index = text_labels_valid_index[0]
assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}" assert (
assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}" input_seq_len[i]
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
assert (
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index input_question_len = text_labels_start_index - input_embeds_start_index
assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i] assert (
input_question_len
+ text_labels_valid_index[-1]
- text_labels_start_index
+ 1
== input_seq_len[i]
)
text_label_start_index_list.append(text_labels_start_index) text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index) text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len) input_question_len_list.append(input_question_len)
model_outputs = self.llm( model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
) )
text_loss = model_outputs.loss text_loss = model_outputs.loss
delay_step = 1 delay_step = 1
# prepare codec lm inputs # prepare codec lm inputs
audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)] audio_codes_lens = [
len(x) + input_question_len_list[i] + delay_step + 1
for i, x in enumerate(speech_codec_ids)
]
max_len_speech_codec = max(audio_codes_lens) max_len_speech_codec = max(audio_codes_lens)
if self.codec_lm_padding_side == "right": if self.codec_lm_padding_side == "right":
audio_codes = [ audio_codes = [
[self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) [self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids) for i, x in enumerate(speech_codec_ids)
] ]
audio_labels = [ audio_labels = [
[self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) [self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
+ [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids) for i, x in enumerate(speech_codec_ids)
] ]
elif self.codec_lm_padding_side == "left": elif self.codec_lm_padding_side == "left":
audio_codes = [ audio_codes = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.mask_token_id]
* (input_question_len_list[i] + delay_step)
+ [self.codec_lm.config.bos_token_id]
+ x
for i, x in enumerate(speech_codec_ids) for i, x in enumerate(speech_codec_ids)
] ]
audio_labels = [ audio_labels = [
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] [self.codec_lm.config.pad_token_id]
* (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.pad_token_id]
* (input_question_len_list[i] + delay_step)
+ x
+ [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids) for i, x in enumerate(speech_codec_ids)
] ]
audio_codes = torch.tensor( audio_codes = torch.tensor(
audio_codes, audio_codes, dtype=torch.int64, device=input_ids.device
dtype=torch.int64,
device=input_ids.device
) )
audio_labels = torch.tensor( audio_labels = torch.tensor(
audio_labels, audio_labels, dtype=torch.int64, device=input_ids.device
dtype=torch.int64,
device=input_ids.device
) )
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id) audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes) audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], [] text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
for i in range(len(text_label_start_index_list)): for i in range(len(text_label_start_index_list)):
text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1] text_last_hidden = model_outputs.hidden_states[-1][
i,
text_input_start_index_list[i] : text_input_start_index_list[i]
+ input_seq_len[i]
- 1,
]
text_last_hidden_lists.append(text_last_hidden) text_last_hidden_lists.append(text_last_hidden)
text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos text_embed = inputs_embeds[
i,
text_input_start_index_list[i]
+ 1 : text_input_start_index_list[i]
+ input_seq_len[i],
] # exclude bos
text_embeds_list.append(text_embed) text_embeds_list.append(text_embed)
text_input_embeds = torch.cat( text_input_embeds = torch.cat(
@ -344,22 +396,34 @@ class SPEECH_LLM(nn.Module):
text_embed, text_embed,
], ],
dim=-1, dim=-1,
)# shape, T, D1 + D2 ) # shape, T, D1 + D2
text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec text_input_embeds = self.speech_token_projector(
text_input_embeds
) # shape, T, D_codec
text_input_embeds_list.append(text_input_embeds) text_input_embeds_list.append(text_input_embeds)
for i in range(audio_embeddings.shape[0]): for i in range(audio_embeddings.shape[0]):
text_input_embeds = text_input_embeds_list[i] text_input_embeds = text_input_embeds_list[i]
if self.codec_lm_padding_side == "right": if self.codec_lm_padding_side == "right":
audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
elif self.codec_lm_padding_side == "left": elif self.codec_lm_padding_side == "left":
start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0] start_idx = torch.where(
audio_codes[i] == self.codec_lm.config.mask_token_id
)[0][0]
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0] start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}" assert (
start_idx == start_idx_re_compute
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx: if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx] text_input_embeds = text_input_embeds[
logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}") : audio_embeddings.shape[1] - start_idx
audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds ]
logging.warning(
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
)
audio_embeddings[
i, start_idx : start_idx + text_input_embeds.shape[0]
] += text_input_embeds
speech_outputs = self.codec_lm( speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask, attention_mask=audio_attention_mask,
@ -369,8 +433,10 @@ class SPEECH_LLM(nn.Module):
) )
last_hidden_state = speech_outputs.hidden_states[-1].clone() last_hidden_state = speech_outputs.hidden_states[-1].clone()
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size) audio_logits = audio_logits.contiguous().view(
-1, self.codec_lm.config.vocab_size
)
audio_labels = audio_labels.contiguous().view(-1) audio_labels = audio_labels.contiguous().view(-1)
audio_labels = audio_labels.masked_fill( audio_labels = audio_labels.masked_fill(
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
```diff
@@ -378,7 +444,6 @@
         codec_loss = self.loss_fct(audio_logits, audio_labels)
         audio_preds = torch.argmax(audio_logits, -1)
-
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
             acc = compute_accuracy(
@@ -392,12 +457,11 @@
                 ignore_label=IGNORE_TOKEN_ID,
             )
             audio_topk_acc = self.audio_accuracy_metric(
-                audio_logits.detach(),
-                audio_labels.detach()).item()
+                audio_logits.detach(), audio_labels.detach()
+            ).item()

         return text_loss, acc, codec_loss, audio_acc, audio_topk_acc

     def decode(
         self,
         fbank: torch.Tensor = None,
```
@ -453,12 +517,12 @@ class SPEECH_LLM(nn.Module):
def decode_with_speech_output( def decode_with_speech_output(
self, self,
fbank: torch.Tensor = None, fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None, # Prompt input_ids input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024, max_text_new_tokens: int = 1024,
max_speech_new_tokens: int = 1024, # Max length for speech tokens max_speech_new_tokens: int = 1024, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]: ) -> Tuple[torch.LongTensor, List[List[int]]]:
""" """
Generates text and corresponding speech tokens using the revised logic. Generates text and corresponding speech tokens using the revised logic.
@ -479,16 +543,22 @@ class SPEECH_LLM(nn.Module):
the generated speech codec tokens for a batch item. the generated speech codec tokens for a batch item.
""" """
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation." assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head: if (
raise ValueError("codec_lm and associated layers must be initialized to generate speech output.") not self.codec_lm
or not self.speech_token_projector
or not self.codec_lm_head
):
raise ValueError(
"codec_lm and associated layers must be initialized to generate speech output."
)
device = next(self.parameters()).device # Use model's device device = next(self.parameters()).device # Use model's device
batch_size = fbank.shape[0] batch_size = fbank.shape[0]
# --- 1. Prepare Prompt Embeddings --- # --- 1. Prepare Prompt Embeddings ---
encoder_outs = self.encoder(fbank) encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs) speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
prompt_embeds = self.llm.get_input_embeddings()(input_ids) prompt_embeds = self.llm.get_input_embeddings()(input_ids)
@ -511,12 +581,12 @@ class SPEECH_LLM(nn.Module):
"eos_token_id": self.llm.config.eos_token_id, "eos_token_id": self.llm.config.eos_token_id,
"pad_token_id": self.llm.config.pad_token_id, "pad_token_id": self.llm.config.pad_token_id,
"num_beams": 1, "num_beams": 1,
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed "do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"top_p": 0.5, "top_p": 0.5,
"top_k": 20, "top_k": 20,
"repetition_penalty": 1.1, "repetition_penalty": 1.1,
"temperature": 0.7, "temperature": 0.7,
**(llm_kwargs or {}) # User-provided kwargs override defaults **(llm_kwargs or {}), # User-provided kwargs override defaults
} }
text_outputs = self.llm.generate( text_outputs = self.llm.generate(
@ -525,17 +595,22 @@ class SPEECH_LLM(nn.Module):
max_new_tokens=max_text_new_tokens, max_new_tokens=max_text_new_tokens,
return_dict_in_generate=True, return_dict_in_generate=True,
output_hidden_states=True, output_hidden_states=True,
**final_llm_kwargs **final_llm_kwargs,
) )
delay_step = 1 delay_step = 1
generated_text_ids = text_outputs.sequences # [B, S_full] generated_text_ids = text_outputs.sequences # [B, S_full]
eos_token_id = self.llm.config.eos_token_id eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D eos_token_embedding = self.llm.get_input_embeddings()(
assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}" torch.tensor([[eos_token_id]], device=device)
) # 1,D
assert (
generated_text_ids[0, -1] == eos_token_id
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [ thinker_token_embeds_org = [
token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states token_hidden_states[0].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
] ]
# shift one for thinker token_embeds, drop the first embeds, and add the eos token # shift one for thinker token_embeds, drop the first embeds, and add the eos token
first_thinker_token_embed = torch.cat( first_thinker_token_embed = torch.cat(
[ [
thinker_token_embeds_org[0][:, 1:], thinker_token_embeds_org[0][:, 1:],
@ -544,19 +619,27 @@ class SPEECH_LLM(nn.Module):
dim=1, dim=1,
) )
thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding] thinker_token_embeds = (
[first_thinker_token_embed]
+ thinker_token_embeds_org[2:]
+ [eos_token_embedding]
)
thinker_hidden_states = [ thinker_hidden_states = [
token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states token_hidden_states[-1].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
] ]
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) # thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_reply_part = [torch.cat( thinker_reply_part = [
[ torch.cat(
thinker_hidden_state, [
thinker_token_embed, thinker_hidden_state,
], thinker_token_embed,
dim=-1, ],
) dim=-1,
for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:]) )
for thinker_hidden_state, thinker_token_embed in zip(
thinker_hidden_states[1:], thinker_token_embeds[1:]
)
] ]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1) thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0] # thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
@ -568,26 +651,35 @@ class SPEECH_LLM(nn.Module):
dim=-1, dim=-1,
) )
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec] thinker_prompt_part = self.speech_token_projector(
thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec] thinker_prompt_part
) # [B, S_full, D_codec]
thinker_reply_part = self.speech_token_projector(
thinker_reply_part
) # [B, S_full, D_codec]
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1] thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full( talker_input_ids = torch.full(
(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device (batch_size, thinker_prompt_part_seq_len + delay_step + 1),
self.codec_lm.config.mask_token_id,
dtype=torch.long,
device=self.llm.device,
) )
talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec] talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
talker_input_ids
) # [B, S_full, D_codec]
thinker_input_embeds = torch.cat( thinker_input_embeds = torch.cat(
[ [
thinker_prompt_part, thinker_prompt_part,
thinker_reply_part[:, :delay_step + 1, :], thinker_reply_part[:, : delay_step + 1, :],
], ],
dim=1, dim=1,
) )
talker_inputs_embeds += thinker_input_embeds talker_inputs_embeds += thinker_input_embeds
thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec] thinker_reply_part = thinker_reply_part[
:, delay_step + 1 :, :
] # [B, S_full, D_codec]
past_key_values = None past_key_values = None
# generated_speech_tokens_list = [[] for _ in range(batch_size)] # generated_speech_tokens_list = [[] for _ in range(batch_size)]
@ -599,10 +691,14 @@ class SPEECH_LLM(nn.Module):
# Get embedding for the *current* input token ID (initially BOS, then generated tokens) # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec] # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
if t > 0: if t > 0:
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec] talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
next_token_ids
) # [B, 1, D_codec]
if thinker_reply_part.shape[1] > 0: if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :] talker_inputs_embeds += thinker_reply_part[:, :1, :]
thinker_reply_part = thinker_reply_part[:, 1:, :] # Remove the first token for next step thinker_reply_part = thinker_reply_part[
:, 1:, :
] # Remove the first token for next step
# # Add the projected text embedding corresponding to the current timestep `t` # # Add the projected text embedding corresponding to the current timestep `t`
# if t < text_context_len: # if t < text_context_len:
# # Text context from the full generated text sequence # # Text context from the full generated text sequence
@ -611,20 +707,24 @@ class SPEECH_LLM(nn.Module):
# else: # else:
# # No more text context to add # # No more text context to add
# inputs_embeds = current_speech_embeds # inputs_embeds = current_speech_embeds
# Forward pass through codec LM for one step # Forward pass through codec LM for one step
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
codec_outputs = self.codec_lm( codec_outputs = self.codec_lm(
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
return_dict=True, return_dict=True,
output_hidden_states=True, output_hidden_states=True,
# No attention mask needed here when using past_key_values and single token input # No attention mask needed here when using past_key_values and single token input
) )
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] #TODO: check shape here last_token_hidden_state = codec_outputs.hidden_states[-1][
:, -1, :
] # [B, D_codec] #TODO: check shape here
# Get logits for the *last* token generated in this step # Get logits for the *last* token generated in this step
next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index next_token_logits = self.codec_lm_head(
last_token_hidden_state
) # Use -1 index
# suppress tokens between 4096:len(vocab)-3 # suppress tokens between 4096:len(vocab)-3
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens? # next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
next_token_ids = topk_sampling( next_token_ids = topk_sampling(
@ -634,11 +734,14 @@ class SPEECH_LLM(nn.Module):
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id: if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break break
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
past_key_values = codec_outputs.past_key_values # Update KV cache past_key_values = codec_outputs.past_key_values # Update KV cache
generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0]) generated_speech_tokens_list.append(
next_token_ids.squeeze(1).cpu().tolist()[0]
)
# --- 6. Return Results --- # --- 6. Return Results ---
return generated_text_ids, generated_speech_tokens_list return generated_text_ids, generated_speech_tokens_list
def compute_accuracy(pad_outputs, pad_targets, ignore_label): def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy. """Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
```diff
@@ -717,4 +820,4 @@ def top_k_top_p_filtering(
         1, sorted_indices, sorted_indices_to_remove
     )
     logits[indices_to_remove] = filter_value
     return logits
```
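For reference, the web demo added later in this commit drives the new `decode_with_speech_output` roughly as follows. This is a condensed sketch: `model`, `tokenizer`, `messages`, and `preprocess()` are set up as in the demo script, and batch size must be 1.

```python
# Condensed sketch of how the demo calls SPEECH_LLM.decode_with_speech_output().
import whisper

fbank = whisper.log_mel_spectrogram("question.wav", device=model.llm.device)
fbank = fbank.unsqueeze(0)  # [1, 80, T]

input_ids, attention_mask = preprocess([messages], tokenizer)  # chat-template prompt, as in the demo
generated_ids, speech_tokens = model.decode_with_speech_output(
    fbank, input_ids.to(model.llm.device), attention_mask.to(model.llm.device)
)
answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# `speech_tokens` is a list of CosyVoice codec ids; see audio_decode_cosyvoice() below.
```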

View File

```diff
@@ -1,13 +1,12 @@
 from typing import Callable, Dict, List, Union

 import torch
-from torch.utils.data.dataloader import DataLoader, default_collate
 from lhotse import validate
 from lhotse.cut import CutSet
 from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
 from lhotse.utils import compute_num_frames, ifnone
 from lhotse.workarounds import Hdf5MemoryIssueFix
+from torch.utils.data.dataloader import DataLoader, default_collate


 class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
```

View File

```diff
@@ -1,26 +1,25 @@
 # Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
 import io
-import numpy as np
-import gradio as gr
-import soundfile as sf
-import gradio.processing_utils as processing_utils
-import tempfile
-from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
-from gradio_client import utils as client_utils
-from argparse import ArgumentParser
-import whisper
-import torch
-from peft import LoraConfig, get_peft_model
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from model import SPEECH_LLM, EncoderProjector
-from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
-import sherpa_onnx
-from cosyvoice.cli.cosyvoice import CosyVoice
 import sys
-sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS')
+from argparse import ArgumentParser
+
+import gradio as gr
+import gradio.processing_utils as processing_utils
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+import torch
+import whisper
+from cosyvoice.cli.cosyvoice import CosyVoice
+from gradio_client import utils as client_utils
+from model import SPEECH_LLM, EncoderProjector
+from peft import LoraConfig, get_peft_model
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")


 def get_model(params, device="cuda"):
@@ -88,7 +87,7 @@ def get_model(params, device="cuda"):
     codec_lm = AutoModelForCausalLM.from_config(
         config=config,
         attn_implementation=attn_implementation,
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
     )
     codec_lm.resize_token_embeddings(codec_vocab_size)
     codec_lm.vocab_size = codec_vocab_size
@@ -102,12 +101,10 @@ def get_model(params, device="cuda"):
         llm,
         encoder_projector,
         codec_lm,
-        codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
     )
-    checkpoint = torch.load(
-        f"{params.checkpoint_path}", map_location="cpu"
-    )
+    checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
     model.load_state_dict(checkpoint, strict=False)
     model.to(device)
@@ -122,27 +119,37 @@ def audio_decode_cosyvoice(audio_tokens, codec_decoder):
     Args:
         audio_tokens (list): List of audio tokens to be processed.
         codec_decoder: Codec decoder for generating audio.

     Returns:
         torch.Tensor: Generated audio waveform.
     """
-    flow_embedding = codec_decoder.frontend.spk2info['中文女']['embedding']
+    flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
     flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
     prompt_speech_feat = torch.zeros(1, 0, 80)
-    tts_mel, _ = codec_decoder.model.flow.inference(token=audio_tokens.to(codec_decoder.model.device),
-        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
-        prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
-        prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
-        prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
-        prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
-        embedding=flow_embedding.to(codec_decoder.model.device),
-        flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),)
-    audio_hat, _ = codec_decoder.model.hift.inference(speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0))
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(codec_decoder.model.device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+            codec_decoder.model.device
+        ),
+        prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+        prompt_token_len=torch.tensor(
+            [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+        prompt_feat_len=torch.tensor(
+            [prompt_speech_feat.shape[1]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        embedding=flow_embedding.to(codec_decoder.model.device),
+        flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
     return audio_hat


 def preprocess(
     messages,
     tokenizer,
```
@ -178,28 +185,14 @@ def preprocess(
attention_mask = input_ids.ne(tokenizer.pad_token_id) attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask return input_ids, attention_mask
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def format_history(history: list): def format_history(history: list):
messages = [] messages = []
for item in history: for item in history:
if isinstance(item["content"], str): if isinstance(item["content"], str):
messages.append({"role": item['role'], "content": item['content']}) messages.append({"role": item["role"], "content": item["content"]})
# elif item["role"] == "user" and (isinstance(item["content"], list) or
# isinstance(item["content"], tuple)):
# file_path = item["content"][0]
# # TODO: check if the file_path's transcript is already in the history
# mime_type = client_utils.get_mimetype(file_path)
# if mime_type.startswith("audio"):
# messages.append({
# "role":
# item['role'],
# "content": item["content"][1] # append audio transcript here
# })
print('predict history: ', messages)
# messages = messages[-2:] # TODO: WAR: add history later
return messages return messages
def decode( def decode(
@ -217,9 +210,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
dtype = torch.float32 dtype = torch.float32
device = model.llm.device device = model.llm.device
feature = feature.to(device, dtype=dtype)#.transpose(1, 2) feature = feature.to(device, dtype=dtype)
# assert feature.shape[2] == 80
input_ids, attention_mask = preprocess([messages], tokenizer) input_ids, attention_mask = preprocess([messages], tokenizer)
generated_ids, audio_tokens = model.decode_with_speech_output( generated_ids, audio_tokens = model.decode_with_speech_output(
@ -227,26 +219,21 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
) )
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# print('hyps: ', hyps, 23333333333333333333333333)
yield {"type": "text", "data": hyps[0]} yield {"type": "text", "data": hyps[0]}
# yield {"type": "text", "data": hyps}
audio_tokens = [token for token in audio_tokens if token < 4096] audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0) audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model) audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
audio = audio_hat.squeeze(0).cpu().numpy() audio = audio_hat.squeeze(0).cpu().numpy()
# sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050)
audio = np.array(audio * 32767).astype(np.int16) audio = np.array(audio * 32767).astype(np.int16)
# yield {"type": "audio", "data": (22050, audio)}
# with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
# sf.write(tmpfile.name, audio, 22050, format="WAV")
# audio_path = tmpfile.name
wav_io = io.BytesIO() wav_io = io.BytesIO()
sf.write(wav_io, audio, samplerate=22050, format="WAV") sf.write(wav_io, audio, samplerate=22050, format="WAV")
wav_io.seek(0) wav_io.seek(0)
wav_bytes = wav_io.getvalue() wav_bytes = wav_io.getvalue()
audio_path = processing_utils.save_bytes_to_cache( audio_path = processing_utils.save_bytes_to_cache(
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE) wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
)
yield {"type": "audio", "data": audio_path} yield {"type": "audio", "data": audio_path}
@ -259,25 +246,27 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
gr.update(visible=True), # stop_btn gr.update(visible=True), # stop_btn
) )
print(2333, history, audio) print(2333, history, audio)
history.append({"role": "user", "content": (audio,)}) history.append({"role": "user", "content": (audio,)})
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}) history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
history.append({"role": "assistant", "content": ""}) history.append({"role": "assistant", "content": ""})
formatted_history = format_history(history=history) # only keep string text format formatted_history = format_history(
history=history
) # only keep string text format
assert audio is not None assert audio is not None
audio_transcript = get_transcript( audio_transcript = get_transcript(
audio, audio,
asr_model, asr_model,
) )
print('audio_transcript: ', audio_transcript)
history[-2]["content"] = audio_transcript history[-2]["content"] = audio_transcript
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device) fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
fbank = fbank.unsqueeze(0) fbank = fbank.unsqueeze(0)
assert fbank.ndim == 3 assert fbank.ndim == 3
# history.append({"role": "assistant", "content": ""}) for chunk in decode(
for chunk in decode(model, token2wav_model, tokenizer, fbank, formatted_history): model, token2wav_model, tokenizer, fbank, formatted_history
):
if chunk["type"] == "text": if chunk["type"] == "text":
history[-1]["content"] = chunk["data"] history[-1]["content"] = chunk["data"]
yield ( yield (
@ -287,10 +276,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
gr.update(visible=True), # stop_btn gr.update(visible=True), # stop_btn
) )
if chunk["type"] == "audio": if chunk["type"] == "audio":
history.append({ history.append(
"role": "assistant", {"role": "assistant", "content": gr.Audio(chunk["data"])}
"content": gr.Audio(chunk["data"]) )
})
# Final yield # Final yield
yield ( yield (
@ -304,8 +292,7 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
with gr.Tab("Online"): with gr.Tab("Online"):
with gr.Row(): with gr.Row():
with gr.Column(scale=1): with gr.Column(scale=1):
microphone = gr.Audio(sources=['microphone'], microphone = gr.Audio(sources=["microphone"], type="filepath")
type="filepath")
submit_btn = gr.Button("Submit", variant="primary") submit_btn = gr.Button("Submit", variant="primary")
stop_btn = gr.Button("Stop", visible=False) stop_btn = gr.Button("Stop", visible=False)
clear_btn = gr.Button("Clear History") clear_btn = gr.Button("Clear History")
@ -315,64 +302,80 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def clear_history(): def clear_history():
return [], gr.update(value=None) return [], gr.update(value=None)
submit_event = submit_btn.click(fn=media_predict, submit_event = submit_btn.click(
inputs=[ fn=media_predict,
microphone, inputs=[
media_chatbot, microphone,
], media_chatbot,
outputs=[ ],
microphone, outputs=[microphone, media_chatbot, submit_btn, stop_btn],
media_chatbot, submit_btn, )
stop_btn
])
stop_btn.click( stop_btn.click(
fn=lambda: fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
(gr.update(visible=True), gr.update(visible=False)),
inputs=None, inputs=None,
outputs=[submit_btn, stop_btn], outputs=[submit_btn, stop_btn],
cancels=[submit_event], cancels=[submit_event],
queue=False) queue=False,
clear_btn.click(fn=clear_history, )
inputs=None, clear_btn.click(
outputs=[media_chatbot, microphone]) fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
)
demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100, demo.queue(default_concurrency_limit=100, max_size=100).launch(
ssr_mode=False, max_threads=100,
share=args.share, ssr_mode=False,
inbrowser=args.inbrowser, share=args.share,
server_port=args.server_port, inbrowser=args.inbrowser,
server_name=args.server_name,) server_port=args.server_port,
server_name=args.server_name,
)
def _get_args(): def _get_args():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('--checkpoint-path', parser.add_argument(
type=str, "--checkpoint-path",
default=None, type=str,
help='Checkpoint name or path, default to %(default)r') default=None,
parser.add_argument('--token2wav-path', help="Checkpoint name or path, default to %(default)r",
type=str, )
default=None, parser.add_argument(
help='Token2Wav path, default to %(default)r') "--token2wav-path",
parser.add_argument('--asr-model-dir', type=str,
type=str, default=None,
default=None, help="Token2Wav path, default to %(default)r",
help='ASR model dir, default to %(default)r') )
parser.add_argument('--flash-attn2', parser.add_argument(
action='store_true', "--asr-model-dir",
default=False, type=str,
help='Enable flash_attention_2 when loading the model.') default=None,
parser.add_argument('--share', help="ASR model dir, default to %(default)r",
action='store_true', )
default=False, parser.add_argument(
help='Create a publicly shareable link for the interface.') "--flash-attn2",
parser.add_argument('--inbrowser', action="store_true",
action='store_true', default=False,
default=False, help="Enable flash_attention_2 when loading the model.",
help='Automatically launch the interface in a new tab on the default browser.') )
parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.') parser.add_argument(
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.') "--share",
action="store_true",
default=False,
help="Create a publicly shareable link for the interface.",
)
parser.add_argument(
"--inbrowser",
action="store_true",
default=False,
help="Automatically launch the interface in a new tab on the default browser.",
)
parser.add_argument(
"--server-port", type=int, default=8001, help="Demo server port."
)
parser.add_argument(
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
)
add_model_arguments(parser) add_model_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
return args return args
```diff
@@ -401,6 +404,7 @@ def read_wave(wave_filename: str):
     return samples_float32, sample_rate


 def get_transcript(audio_path, recognizer):
     samples, sample_rate = read_wave(audio_path)
     s = recognizer.create_stream()
@@ -408,10 +412,13 @@ def get_transcript(audio_path, recognizer):
     recognizer.decode_streams([s])
     return s.result.text


 if __name__ == "__main__":
     args = _get_args()
     model, tokenizer = get_model(args)
-    token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
+    token2wav = CosyVoice(
+        args.token2wav_path, load_jit=False, load_trt=False, fp16=False
+    )

     asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
         paraformer=f"{args.asr_model_dir}/model.int8.onnx",
@@ -423,4 +430,4 @@ if __name__ == "__main__":
         debug=False,
     )
     _launch_demo(args, model, tokenizer, token2wav, asr_model)
```