mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 10:02:22 +00:00

update README

This commit is contained in:
parent 448a4eeea7
commit 360f0aa397

89  egs/speech_llm/SPEECH2SPEECH/README.md  Normal file
@@ -0,0 +1,89 @@

# Introduction

This recipe includes scripts for training speech2speech models.

# SPEECH2SPEECH

The following table lists the folders for different tasks.

|Recipe | Speech Input | Speech Output | Comment|
|--------------|--------------|---------------|--------|
|Qwen-Omni like| Continuous Embeddings| CosyVoice1 50Hz Single-codebook Token | Text-driven; uses a Thinker LLM for text tokens and a small Talker LM for speech tokens |

### [Qwen-Omni like Speech2Speech Recipe](./qwen_omni)

[Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni) style model trained on the [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.

<br>
<p align="center">
  <img src="assets/framework.jpg" width="800"/>
</p>
<br>
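At inference time the model runs in two stages: the Thinker LLM consumes the projected Whisper encoder output and generates the text reply, then its hidden states (concatenated with the text token embeddings and projected) condition a small Talker LM that autoregressively emits CosyVoice1 codec tokens, which are finally vocoded to a waveform. The snippet below is only an illustrative sketch pieced together from `model.py` in this recipe; the helper name, arguments, and calling convention are assumptions, not a stable API.

```python
# Illustrative sketch (not a packaged API) of the two-stage Thinker -> Talker
# inference path. `model` is assumed to be a SPEECH_LLM instance and `tokenizer`
# its text tokenizer, built as in the recipe's demo/decoding scripts; `input_ids`
# and `attention_mask` are the chat-template prompt for a single utterance.
from typing import List, Tuple

import torch
import whisper


def speech_to_speech_sketch(
    model, tokenizer, wav_path: str, input_ids: torch.LongTensor, attention_mask: torch.Tensor
) -> Tuple[str, List[int]]:
    # 80-dim Whisper log-mel features, shape (1, n_mels, T), on the LLM's device.
    fbank = whisper.log_mel_spectrogram(wav_path, device=model.llm.device).unsqueeze(0)

    # Stage 1: the Thinker LLM generates the text reply.
    # Stage 2: its hidden states, concatenated with the token embeddings and
    # projected, drive the Talker LM, which emits CosyVoice1 50 Hz
    # single-codebook tokens.
    generated_text_ids, speech_tokens = model.decode_with_speech_output(
        fbank=fbank, input_ids=input_ids, attention_mask=attention_mask
    )
    text = tokenizer.batch_decode(generated_text_ids, skip_special_tokens=True)[0]

    # Ids >= 4096 are special codec-LM tokens; the remaining tokens can be
    # vocoded to a waveform with CosyVoice (see audio_decode_cosyvoice in the
    # demo script added by this commit).
    speech_tokens = [t for t in speech_tokens if t < 4096]
    return text, speech_tokens
```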
The command for training is:
```bash
pip install -r whisper_llm_zh/requirements.txt

pip install "huggingface_hub[cli]"
mkdir -p models/whisper models/qwen

# For the Aishell fine-tuned Whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned Whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

# First, we only train the projector and freeze other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora False --unfreeze-llm False

# Then we jointly train the projector and the LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True \
  --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
```
The command for decoding is:
```bash
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

# For the Aishell fine-tuned Whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned Whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s $(pwd)/models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
```
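The training and decoding commands above read Lhotse manifests from `data/fbank`. For reference, the usage header of `local/compute_whisper_fbank.py` (added in this commit) shows how such Whisper-fbank manifests can be computed from a Hugging Face SLAM-Omni dataset; the dataset name, keys, and prefix below are copied from that header and would need to be adapted for worstchan/Belle_1.4M-SLAM-Omni.

```bash
# Example taken from the script's usage docstring; adjust
# --huggingface-dataset-path-or-name, --audio-key/--text-key and --prefix
# for the dataset you actually use.
python3 local/compute_whisper_fbank.py \
  --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
  --out-dir data/fbank \
  --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
  --audio-key question_audio --text-key answer \
  --prefix ultrachat
```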
BIN  egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg  Normal file (binary file not shown, 331 KiB)
@@ -17,6 +17,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python3 local/compute_whisper_fbank.py \
    --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
    --out-dir data/fbank \
    --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
    --audio-key question_audio --text-key answer \
    --prefix ultrachat
"""


import argparse
import logging
@@ -126,7 +136,7 @@ def compute_fbank(args):
    num_digits = 5
    for i in range(num_shards):
        shard = dataset.shard(num_shards, i)
        shard = shard.take(10)  # for testing
        # shard = shard.take(10)  # for testing
        logging.info(
            f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
        )
@@ -159,8 +169,6 @@ def compute_fbank(args):
        logging.info(f"Saving to {cuts_path}")
        # see https://github.com/lhotse-speech/lhotse/issues/1125
        cut_set.drop_recordings().to_file(cuts_path)
        if i > 1:
            break


def main():
@ -1,11 +1,14 @@
|
||||
from typing import List, Tuple # Added for type hints
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.trainer_pt_utils import LabelSmoother
|
||||
from typing import List, Tuple # Added for type hints
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from transformers.trainer_pt_utils import LabelSmoother
|
||||
|
||||
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||
import logging
|
||||
|
||||
|
||||
class EncoderProjector(nn.Module):
|
||||
"""
|
||||
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
|
||||
@ -69,7 +72,8 @@ class SPEECH_LLM(nn.Module):
|
||||
self.codec_lm = codec_lm
|
||||
if self.codec_lm:
|
||||
self.speech_token_projector = nn.Linear(
|
||||
self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
|
||||
self.llm.config.hidden_size + self.llm.config.hidden_size,
|
||||
self.codec_lm.config.hidden_size,
|
||||
)
|
||||
self.codec_lm_head = nn.Linear(
|
||||
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
|
||||
@ -89,6 +93,7 @@ class SPEECH_LLM(nn.Module):
|
||||
multidim_average="global",
|
||||
ignore_index=IGNORE_TOKEN_ID,
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_speech_features(
|
||||
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
|
||||
):
|
||||
@ -274,68 +279,115 @@ class SPEECH_LLM(nn.Module):
|
||||
) = self._merge_input_ids_with_speech_features(
|
||||
speech_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
input_seq_len = attention_mask.sum(dim=1) # shape, B
|
||||
text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
|
||||
input_seq_len = attention_mask.sum(dim=1) # shape, B
|
||||
(
|
||||
text_label_start_index_list,
|
||||
text_input_start_index_list,
|
||||
input_question_len_list,
|
||||
) = ([], [], [])
|
||||
for i in range(labels.shape[0]):
|
||||
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
|
||||
input_embeds_start_index = input_embeds_valid_index[0]
|
||||
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
|
||||
text_labels_start_index = text_labels_valid_index[0]
|
||||
|
||||
assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
|
||||
assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
|
||||
assert (
|
||||
input_seq_len[i]
|
||||
== input_embeds_valid_index[-1] - input_embeds_start_index + 1
|
||||
), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
|
||||
assert (
|
||||
input_embeds_valid_index[-1] == text_labels_valid_index[-1]
|
||||
), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
|
||||
input_question_len = text_labels_start_index - input_embeds_start_index
|
||||
assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
|
||||
assert (
|
||||
input_question_len
|
||||
+ text_labels_valid_index[-1]
|
||||
- text_labels_start_index
|
||||
+ 1
|
||||
== input_seq_len[i]
|
||||
)
|
||||
text_label_start_index_list.append(text_labels_start_index)
|
||||
text_input_start_index_list.append(input_embeds_start_index)
|
||||
input_question_len_list.append(input_question_len)
|
||||
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
text_loss = model_outputs.loss
|
||||
delay_step = 1
|
||||
# prepare codec lm inputs
|
||||
audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
|
||||
audio_codes_lens = [
|
||||
len(x) + input_question_len_list[i] + delay_step + 1
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
max_len_speech_codec = max(audio_codes_lens)
|
||||
|
||||
if self.codec_lm_padding_side == "right":
|
||||
audio_codes = [
|
||||
[self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
|
||||
[self.codec_lm.config.mask_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ [self.codec_lm.config.bos_token_id]
|
||||
+ x
|
||||
+ [self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
audio_labels = [
|
||||
[self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
|
||||
[self.codec_lm.config.pad_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ x
|
||||
+ [self.codec_lm.config.eos_token_id]
|
||||
+ [self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
elif self.codec_lm_padding_side == "left":
|
||||
audio_codes = [
|
||||
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
|
||||
[self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
+ [self.codec_lm.config.mask_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ [self.codec_lm.config.bos_token_id]
|
||||
+ x
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
audio_labels = [
|
||||
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
|
||||
[self.codec_lm.config.pad_token_id]
|
||||
* (max_len_speech_codec - audio_codes_lens[i])
|
||||
+ [self.codec_lm.config.pad_token_id]
|
||||
* (input_question_len_list[i] + delay_step)
|
||||
+ x
|
||||
+ [self.codec_lm.config.eos_token_id]
|
||||
for i, x in enumerate(speech_codec_ids)
|
||||
]
|
||||
audio_codes = torch.tensor(
|
||||
audio_codes,
|
||||
dtype=torch.int64,
|
||||
device=input_ids.device
|
||||
audio_codes, dtype=torch.int64, device=input_ids.device
|
||||
)
|
||||
audio_labels = torch.tensor(
|
||||
audio_labels,
|
||||
dtype=torch.int64,
|
||||
device=input_ids.device
|
||||
audio_labels, dtype=torch.int64, device=input_ids.device
|
||||
)
|
||||
|
||||
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
|
||||
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
|
||||
|
||||
|
||||
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
|
||||
for i in range(len(text_label_start_index_list)):
|
||||
text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
|
||||
text_last_hidden = model_outputs.hidden_states[-1][
|
||||
i,
|
||||
text_input_start_index_list[i] : text_input_start_index_list[i]
|
||||
+ input_seq_len[i]
|
||||
- 1,
|
||||
]
|
||||
text_last_hidden_lists.append(text_last_hidden)
|
||||
text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos
|
||||
text_embed = inputs_embeds[
|
||||
i,
|
||||
text_input_start_index_list[i]
|
||||
+ 1 : text_input_start_index_list[i]
|
||||
+ input_seq_len[i],
|
||||
] # exclude bos
|
||||
text_embeds_list.append(text_embed)
|
||||
|
||||
text_input_embeds = torch.cat(
|
||||
@ -344,22 +396,34 @@ class SPEECH_LLM(nn.Module):
|
||||
text_embed,
|
||||
],
|
||||
dim=-1,
|
||||
)# shape, T, D1 + D2
|
||||
text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec
|
||||
) # shape, T, D1 + D2
|
||||
text_input_embeds = self.speech_token_projector(
|
||||
text_input_embeds
|
||||
) # shape, T, D_codec
|
||||
text_input_embeds_list.append(text_input_embeds)
|
||||
|
||||
|
||||
for i in range(audio_embeddings.shape[0]):
|
||||
text_input_embeds = text_input_embeds_list[i]
|
||||
if self.codec_lm_padding_side == "right":
|
||||
audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
|
||||
audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
|
||||
elif self.codec_lm_padding_side == "left":
|
||||
start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
|
||||
start_idx = torch.where(
|
||||
audio_codes[i] == self.codec_lm.config.mask_token_id
|
||||
)[0][0]
|
||||
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
|
||||
assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
|
||||
assert (
|
||||
start_idx == start_idx_re_compute
|
||||
), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
|
||||
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
|
||||
text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
|
||||
logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
|
||||
audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds
|
||||
text_input_embeds = text_input_embeds[
|
||||
: audio_embeddings.shape[1] - start_idx
|
||||
]
|
||||
logging.warning(
|
||||
f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
|
||||
)
|
||||
audio_embeddings[
|
||||
i, start_idx : start_idx + text_input_embeds.shape[0]
|
||||
] += text_input_embeds
|
||||
|
||||
speech_outputs = self.codec_lm(
|
||||
attention_mask=audio_attention_mask,
|
||||
@ -369,8 +433,10 @@ class SPEECH_LLM(nn.Module):
|
||||
)
|
||||
last_hidden_state = speech_outputs.hidden_states[-1].clone()
|
||||
|
||||
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
|
||||
audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size)
|
||||
audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
|
||||
audio_logits = audio_logits.contiguous().view(
|
||||
-1, self.codec_lm.config.vocab_size
|
||||
)
|
||||
audio_labels = audio_labels.contiguous().view(-1)
|
||||
audio_labels = audio_labels.masked_fill(
|
||||
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
|
||||
@ -378,7 +444,6 @@ class SPEECH_LLM(nn.Module):
|
||||
codec_loss = self.loss_fct(audio_logits, audio_labels)
|
||||
audio_preds = torch.argmax(audio_logits, -1)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
acc = compute_accuracy(
|
||||
@ -392,12 +457,11 @@ class SPEECH_LLM(nn.Module):
|
||||
ignore_label=IGNORE_TOKEN_ID,
|
||||
)
|
||||
audio_topk_acc = self.audio_accuracy_metric(
|
||||
audio_logits.detach(),
|
||||
audio_labels.detach()).item()
|
||||
|
||||
audio_logits.detach(), audio_labels.detach()
|
||||
).item()
|
||||
|
||||
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
|
||||
|
||||
|
||||
def decode(
|
||||
self,
|
||||
fbank: torch.Tensor = None,
|
||||
@ -453,12 +517,12 @@ class SPEECH_LLM(nn.Module):
|
||||
def decode_with_speech_output(
|
||||
self,
|
||||
fbank: torch.Tensor = None,
|
||||
input_ids: torch.LongTensor = None, # Prompt input_ids
|
||||
attention_mask: torch.Tensor = None, # Prompt attention_mask
|
||||
input_ids: torch.LongTensor = None, # Prompt input_ids
|
||||
attention_mask: torch.Tensor = None, # Prompt attention_mask
|
||||
max_text_new_tokens: int = 1024,
|
||||
max_speech_new_tokens: int = 1024, # Max length for speech tokens
|
||||
llm_kwargs: dict = None, # Kwargs for text LLM generate
|
||||
codec_lm_kwargs: dict = None # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
|
||||
max_speech_new_tokens: int = 1024, # Max length for speech tokens
|
||||
llm_kwargs: dict = None, # Kwargs for text LLM generate
|
||||
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
|
||||
) -> Tuple[torch.LongTensor, List[List[int]]]:
|
||||
"""
|
||||
Generates text and corresponding speech tokens using the revised logic.
|
||||
@ -479,16 +543,22 @@ class SPEECH_LLM(nn.Module):
|
||||
the generated speech codec tokens for a batch item.
|
||||
"""
|
||||
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
|
||||
if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
|
||||
raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
|
||||
if (
|
||||
not self.codec_lm
|
||||
or not self.speech_token_projector
|
||||
or not self.codec_lm_head
|
||||
):
|
||||
raise ValueError(
|
||||
"codec_lm and associated layers must be initialized to generate speech output."
|
||||
)
|
||||
|
||||
device = next(self.parameters()).device # Use model's device
|
||||
device = next(self.parameters()).device # Use model's device
|
||||
batch_size = fbank.shape[0]
|
||||
|
||||
# --- 1. Prepare Prompt Embeddings ---
|
||||
encoder_outs = self.encoder(fbank)
|
||||
speech_features = self.encoder_projector(encoder_outs)
|
||||
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
|
||||
speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
|
||||
|
||||
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -511,12 +581,12 @@ class SPEECH_LLM(nn.Module):
|
||||
"eos_token_id": self.llm.config.eos_token_id,
|
||||
"pad_token_id": self.llm.config.pad_token_id,
|
||||
"num_beams": 1,
|
||||
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
|
||||
"do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
|
||||
"top_p": 0.5,
|
||||
"top_k": 20,
|
||||
"repetition_penalty": 1.1,
|
||||
"temperature": 0.7,
|
||||
**(llm_kwargs or {}) # User-provided kwargs override defaults
|
||||
**(llm_kwargs or {}), # User-provided kwargs override defaults
|
||||
}
|
||||
|
||||
text_outputs = self.llm.generate(
|
||||
@ -525,17 +595,22 @@ class SPEECH_LLM(nn.Module):
|
||||
max_new_tokens=max_text_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
output_hidden_states=True,
|
||||
**final_llm_kwargs
|
||||
**final_llm_kwargs,
|
||||
)
|
||||
delay_step = 1
|
||||
generated_text_ids = text_outputs.sequences # [B, S_full]
|
||||
generated_text_ids = text_outputs.sequences # [B, S_full]
|
||||
eos_token_id = self.llm.config.eos_token_id
|
||||
eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D
|
||||
assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
|
||||
eos_token_embedding = self.llm.get_input_embeddings()(
|
||||
torch.tensor([[eos_token_id]], device=device)
|
||||
) # 1,D
|
||||
assert (
|
||||
generated_text_ids[0, -1] == eos_token_id
|
||||
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
|
||||
thinker_token_embeds_org = [
|
||||
token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
|
||||
token_hidden_states[0].to(self.llm.device)
|
||||
for token_hidden_states in text_outputs.hidden_states
|
||||
]
|
||||
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
|
||||
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
|
||||
first_thinker_token_embed = torch.cat(
|
||||
[
|
||||
thinker_token_embeds_org[0][:, 1:],
|
||||
@ -544,19 +619,27 @@ class SPEECH_LLM(nn.Module):
|
||||
dim=1,
|
||||
)
|
||||
|
||||
thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
|
||||
thinker_token_embeds = (
|
||||
[first_thinker_token_embed]
|
||||
+ thinker_token_embeds_org[2:]
|
||||
+ [eos_token_embedding]
|
||||
)
|
||||
thinker_hidden_states = [
|
||||
token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
|
||||
token_hidden_states[-1].to(self.llm.device)
|
||||
for token_hidden_states in text_outputs.hidden_states
|
||||
]
|
||||
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
|
||||
thinker_reply_part = [torch.cat(
|
||||
[
|
||||
thinker_hidden_state,
|
||||
thinker_token_embed,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
|
||||
thinker_reply_part = [
|
||||
torch.cat(
|
||||
[
|
||||
thinker_hidden_state,
|
||||
thinker_token_embed,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
for thinker_hidden_state, thinker_token_embed in zip(
|
||||
thinker_hidden_states[1:], thinker_token_embeds[1:]
|
||||
)
|
||||
]
|
||||
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
|
||||
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
|
||||
@ -568,26 +651,35 @@ class SPEECH_LLM(nn.Module):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec]
|
||||
thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec]
|
||||
|
||||
thinker_prompt_part = self.speech_token_projector(
|
||||
thinker_prompt_part
|
||||
) # [B, S_full, D_codec]
|
||||
thinker_reply_part = self.speech_token_projector(
|
||||
thinker_reply_part
|
||||
) # [B, S_full, D_codec]
|
||||
|
||||
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
|
||||
talker_input_ids = torch.full(
|
||||
(batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
|
||||
(batch_size, thinker_prompt_part_seq_len + delay_step + 1),
|
||||
self.codec_lm.config.mask_token_id,
|
||||
dtype=torch.long,
|
||||
device=self.llm.device,
|
||||
)
|
||||
talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
|
||||
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
|
||||
talker_input_ids
|
||||
) # [B, S_full, D_codec]
|
||||
thinker_input_embeds = torch.cat(
|
||||
[
|
||||
thinker_prompt_part,
|
||||
thinker_reply_part[:, :delay_step + 1, :],
|
||||
thinker_reply_part[:, : delay_step + 1, :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
talker_inputs_embeds += thinker_input_embeds
|
||||
thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec]
|
||||
|
||||
thinker_reply_part = thinker_reply_part[
|
||||
:, delay_step + 1 :, :
|
||||
] # [B, S_full, D_codec]
|
||||
|
||||
past_key_values = None
|
||||
# generated_speech_tokens_list = [[] for _ in range(batch_size)]
|
||||
@ -599,10 +691,14 @@ class SPEECH_LLM(nn.Module):
|
||||
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
|
||||
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
|
||||
if t > 0:
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
|
||||
next_token_ids
|
||||
) # [B, 1, D_codec]
|
||||
if thinker_reply_part.shape[1] > 0:
|
||||
talker_inputs_embeds += thinker_reply_part[:, :1, :]
|
||||
thinker_reply_part = thinker_reply_part[:, 1:, :] # Remove the first token for next step
|
||||
thinker_reply_part = thinker_reply_part[
|
||||
:, 1:, :
|
||||
] # Remove the first token for next step
|
||||
# # Add the projected text embedding corresponding to the current timestep `t`
|
||||
# if t < text_context_len:
|
||||
# # Text context from the full generated text sequence
|
||||
@ -611,20 +707,24 @@ class SPEECH_LLM(nn.Module):
|
||||
# else:
|
||||
# # No more text context to add
|
||||
# inputs_embeds = current_speech_embeds
|
||||
|
||||
|
||||
# Forward pass through codec LM for one step
|
||||
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
|
||||
codec_outputs = self.codec_lm(
|
||||
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
|
||||
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
# No attention mask needed here when using past_key_values and single token input
|
||||
)
|
||||
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] #TODO: check shape here
|
||||
last_token_hidden_state = codec_outputs.hidden_states[-1][
|
||||
:, -1, :
|
||||
] # [B, D_codec] #TODO: check shape here
|
||||
# Get logits for the *last* token generated in this step
|
||||
next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
|
||||
next_token_logits = self.codec_lm_head(
|
||||
last_token_hidden_state
|
||||
) # Use -1 index
|
||||
# suppress tokens between 4096:len(vocab)-3
|
||||
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
|
||||
next_token_ids = topk_sampling(
|
||||
@ -634,11 +734,14 @@ class SPEECH_LLM(nn.Module):
|
||||
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
|
||||
break
|
||||
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
|
||||
past_key_values = codec_outputs.past_key_values # Update KV cache
|
||||
generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])
|
||||
past_key_values = codec_outputs.past_key_values # Update KV cache
|
||||
generated_speech_tokens_list.append(
|
||||
next_token_ids.squeeze(1).cpu().tolist()[0]
|
||||
)
|
||||
# --- 6. Return Results ---
|
||||
return generated_text_ids, generated_speech_tokens_list
|
||||
|
||||
|
||||
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
"""Calculate accuracy.
|
||||
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
|
||||
@ -717,4 +820,4 @@ def top_k_top_p_filtering(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
return logits
|
@@ -1,13 +1,12 @@
from typing import Callable, Dict, List, Union

import torch
from torch.utils.data.dataloader import DataLoader, default_collate

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix
from torch.utils.data.dataloader import DataLoader, default_collate


class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
@ -1,26 +1,25 @@
|
||||
# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import gradio as gr
|
||||
import soundfile as sf
|
||||
|
||||
import gradio.processing_utils as processing_utils
|
||||
import tempfile
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
|
||||
from gradio_client import utils as client_utils
|
||||
|
||||
from argparse import ArgumentParser
|
||||
import whisper
|
||||
import torch
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from model import SPEECH_LLM, EncoderProjector
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
import sherpa_onnx
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
import sys
|
||||
sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS')
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import gradio as gr
|
||||
import gradio.processing_utils as processing_utils
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import whisper
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from gradio_client import utils as client_utils
|
||||
from model import SPEECH_LLM, EncoderProjector
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
|
||||
def get_model(params, device="cuda"):
|
||||
@ -88,7 +87,7 @@ def get_model(params, device="cuda"):
|
||||
codec_lm = AutoModelForCausalLM.from_config(
|
||||
config=config,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch.float16
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
codec_lm.resize_token_embeddings(codec_vocab_size)
|
||||
codec_lm.vocab_size = codec_vocab_size
|
||||
@ -102,12 +101,10 @@ def get_model(params, device="cuda"):
|
||||
llm,
|
||||
encoder_projector,
|
||||
codec_lm,
|
||||
codec_lm_padding_side= "left" if params.use_flash_attn else "right",
|
||||
codec_lm_padding_side="left" if params.use_flash_attn else "right",
|
||||
)
|
||||
|
||||
checkpoint = torch.load(
|
||||
f"{params.checkpoint_path}", map_location="cpu"
|
||||
)
|
||||
checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
model.to(device)
|
||||
@ -122,27 +119,37 @@ def audio_decode_cosyvoice(audio_tokens, codec_decoder):
|
||||
Args:
|
||||
audio_tokens (list): List of audio tokens to be processed.
|
||||
codec_decoder: Codec decoder for generating audio.
|
||||
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Generated audio waveform.
|
||||
"""
|
||||
flow_embedding = codec_decoder.frontend.spk2info['中文女']['embedding']
|
||||
flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
|
||||
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
|
||||
prompt_speech_feat = torch.zeros(1, 0, 80)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
|
||||
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
|
||||
prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
|
||||
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
|
||||
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
|
||||
embedding=flow_embedding.to(codec_decoder.model.device),
|
||||
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
|
||||
prompt_token_len=torch.tensor(
|
||||
[flow_prompt_speech_token.shape[1]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
|
||||
prompt_feat_len=torch.tensor(
|
||||
[prompt_speech_feat.shape[1]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
embedding=flow_embedding.to(codec_decoder.model.device),
|
||||
flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
|
||||
)
|
||||
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0))
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer,
|
||||
@ -178,28 +185,14 @@ def preprocess(
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
|
||||
|
||||
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
def format_history(history: list):
|
||||
messages = []
|
||||
for item in history:
|
||||
if isinstance(item["content"], str):
|
||||
messages.append({"role": item['role'], "content": item['content']})
|
||||
# elif item["role"] == "user" and (isinstance(item["content"], list) or
|
||||
# isinstance(item["content"], tuple)):
|
||||
# file_path = item["content"][0]
|
||||
# # TODO: check if the file_path's transcript is already in the history
|
||||
# mime_type = client_utils.get_mimetype(file_path)
|
||||
# if mime_type.startswith("audio"):
|
||||
# messages.append({
|
||||
# "role":
|
||||
# item['role'],
|
||||
# "content": item["content"][1] # append audio transcript here
|
||||
# })
|
||||
print('predict history: ', messages)
|
||||
# messages = messages[-2:] # TODO: WAR: add history later
|
||||
messages.append({"role": item["role"], "content": item["content"]})
|
||||
return messages
|
||||
|
||||
def decode(
|
||||
@ -217,9 +210,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
dtype = torch.float32
|
||||
device = model.llm.device
|
||||
|
||||
feature = feature.to(device, dtype=dtype)#.transpose(1, 2)
|
||||
# assert feature.shape[2] == 80
|
||||
|
||||
feature = feature.to(device, dtype=dtype)
|
||||
|
||||
input_ids, attention_mask = preprocess([messages], tokenizer)
|
||||
|
||||
generated_ids, audio_tokens = model.decode_with_speech_output(
|
||||
@ -227,26 +219,21 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
)
|
||||
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
# print('hyps: ', hyps, 23333333333333333333333333)
|
||||
|
||||
yield {"type": "text", "data": hyps[0]}
|
||||
# yield {"type": "text", "data": hyps}
|
||||
|
||||
audio_tokens = [token for token in audio_tokens if token < 4096]
|
||||
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
|
||||
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
|
||||
audio = audio_hat.squeeze(0).cpu().numpy()
|
||||
# sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050)
|
||||
audio = audio_hat.squeeze(0).cpu().numpy()
|
||||
audio = np.array(audio * 32767).astype(np.int16)
|
||||
# yield {"type": "audio", "data": (22050, audio)}
|
||||
# with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
|
||||
# sf.write(tmpfile.name, audio, 22050, format="WAV")
|
||||
# audio_path = tmpfile.name
|
||||
wav_io = io.BytesIO()
|
||||
sf.write(wav_io, audio, samplerate=22050, format="WAV")
|
||||
wav_io.seek(0)
|
||||
wav_bytes = wav_io.getvalue()
|
||||
audio_path = processing_utils.save_bytes_to_cache(
|
||||
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
|
||||
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
|
||||
)
|
||||
|
||||
yield {"type": "audio", "data": audio_path}
|
||||
|
||||
@ -259,25 +246,27 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
gr.update(visible=True), # stop_btn
|
||||
)
|
||||
print(2333, history, audio)
|
||||
history.append({"role": "user", "content": (audio,)})
|
||||
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
|
||||
history.append({"role": "user", "content": (audio,)})
|
||||
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
formatted_history = format_history(history=history) # only keep string text format
|
||||
formatted_history = format_history(
|
||||
history=history
|
||||
) # only keep string text format
|
||||
|
||||
assert audio is not None
|
||||
audio_transcript = get_transcript(
|
||||
audio,
|
||||
asr_model,
|
||||
)
|
||||
print('audio_transcript: ', audio_transcript)
|
||||
history[-2]["content"] = audio_transcript
|
||||
|
||||
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
|
||||
fbank = fbank.unsqueeze(0)
|
||||
assert fbank.ndim == 3
|
||||
|
||||
# history.append({"role": "assistant", "content": ""})
|
||||
for chunk in decode(model, token2wav_model, tokenizer, fbank, formatted_history):
|
||||
for chunk in decode(
|
||||
model, token2wav_model, tokenizer, fbank, formatted_history
|
||||
):
|
||||
if chunk["type"] == "text":
|
||||
history[-1]["content"] = chunk["data"]
|
||||
yield (
|
||||
@ -287,10 +276,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
gr.update(visible=True), # stop_btn
|
||||
)
|
||||
if chunk["type"] == "audio":
|
||||
history.append({
|
||||
"role": "assistant",
|
||||
"content": gr.Audio(chunk["data"])
|
||||
})
|
||||
history.append(
|
||||
{"role": "assistant", "content": gr.Audio(chunk["data"])}
|
||||
)
|
||||
|
||||
# Final yield
|
||||
yield (
|
||||
@ -304,8 +292,7 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
with gr.Tab("Online"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
microphone = gr.Audio(sources=['microphone'],
|
||||
type="filepath")
|
||||
microphone = gr.Audio(sources=["microphone"], type="filepath")
|
||||
submit_btn = gr.Button("Submit", variant="primary")
|
||||
stop_btn = gr.Button("Stop", visible=False)
|
||||
clear_btn = gr.Button("Clear History")
|
||||
@ -315,64 +302,80 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
|
||||
def clear_history():
|
||||
return [], gr.update(value=None)
|
||||
|
||||
submit_event = submit_btn.click(fn=media_predict,
|
||||
inputs=[
|
||||
microphone,
|
||||
media_chatbot,
|
||||
],
|
||||
outputs=[
|
||||
microphone,
|
||||
media_chatbot, submit_btn,
|
||||
stop_btn
|
||||
])
|
||||
submit_event = submit_btn.click(
|
||||
fn=media_predict,
|
||||
inputs=[
|
||||
microphone,
|
||||
media_chatbot,
|
||||
],
|
||||
outputs=[microphone, media_chatbot, submit_btn, stop_btn],
|
||||
)
|
||||
stop_btn.click(
|
||||
fn=lambda:
|
||||
(gr.update(visible=True), gr.update(visible=False)),
|
||||
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
cancels=[submit_event],
|
||||
queue=False)
|
||||
clear_btn.click(fn=clear_history,
|
||||
inputs=None,
|
||||
outputs=[media_chatbot, microphone])
|
||||
queue=False,
|
||||
)
|
||||
clear_btn.click(
|
||||
fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
|
||||
)
|
||||
|
||||
demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
|
||||
ssr_mode=False,
|
||||
share=args.share,
|
||||
inbrowser=args.inbrowser,
|
||||
server_port=args.server_port,
|
||||
server_name=args.server_name,)
|
||||
demo.queue(default_concurrency_limit=100, max_size=100).launch(
|
||||
max_threads=100,
|
||||
ssr_mode=False,
|
||||
share=args.share,
|
||||
inbrowser=args.inbrowser,
|
||||
server_port=args.server_port,
|
||||
server_name=args.server_name,
|
||||
)
|
||||
|
||||
|
||||
def _get_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument('--checkpoint-path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Checkpoint name or path, default to %(default)r')
|
||||
parser.add_argument('--token2wav-path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Token2Wav path, default to %(default)r')
|
||||
parser.add_argument('--asr-model-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='ASR model dir, default to %(default)r')
|
||||
parser.add_argument('--flash-attn2',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Enable flash_attention_2 when loading the model.')
|
||||
parser.add_argument('--share',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Create a publicly shareable link for the interface.')
|
||||
parser.add_argument('--inbrowser',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Automatically launch the interface in a new tab on the default browser.')
|
||||
parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.')
|
||||
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
|
||||
parser.add_argument(
|
||||
"--checkpoint-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint name or path, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Token2Wav path, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--asr-model-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="ASR model dir, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash-attn2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable flash_attention_2 when loading the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--share",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Create a publicly shareable link for the interface.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inbrowser",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Automatically launch the interface in a new tab on the default browser.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-port", type=int, default=8001, help="Demo server port."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-name", type=str, default="127.0.0.1", help="Demo server name."
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@ -401,6 +404,7 @@ def read_wave(wave_filename: str):
|
||||
|
||||
return samples_float32, sample_rate
|
||||
|
||||
|
||||
def get_transcript(audio_path, recognizer):
|
||||
samples, sample_rate = read_wave(audio_path)
|
||||
s = recognizer.create_stream()
|
||||
@ -408,10 +412,13 @@ def get_transcript(audio_path, recognizer):
|
||||
recognizer.decode_streams([s])
|
||||
return s.result.text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _get_args()
|
||||
model, tokenizer = get_model(args)
|
||||
token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
|
||||
token2wav = CosyVoice(
|
||||
args.token2wav_path, load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
|
||||
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
|
||||
@ -423,4 +430,4 @@ if __name__ == "__main__":
|
||||
debug=False,
|
||||
)
|
||||
|
||||
_launch_demo(args, model, tokenizer, token2wav, asr_model)
|
||||
_launch_demo(args, model, tokenizer, token2wav, asr_model)
|