diff --git a/egs/speech_llm/SPEECH2SPEECH/README.md b/egs/speech_llm/SPEECH2SPEECH/README.md
new file mode 100644
index 000000000..e4738eeef
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/README.md
@@ -0,0 +1,89 @@
+
+# Introduction
+
+This recipe contains scripts for training speech-to-speech (S2S) models.
+
+# SPEECH2SPEECH
+
+The following table lists the recipe folders for the different tasks.
+
+| Recipe | Speech Input | Speech Output | Comment |
+|--------|--------------|---------------|---------|
+| Qwen-Omni-like | Continuous embeddings | CosyVoice 1 50 Hz single-codebook tokens | Text-driven; a Thinker LLM generates the text tokens and a small Talker LLM generates the speech tokens |
+
+### [Qwen-Omni-like Speech2Speech Recipe](./qwen_omni)
+
+A [Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni)-style model trained on the [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.
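+
+The Talker LM is conditioned on the Thinker LLM: at each text step, the Thinker's last-layer hidden state is concatenated with the corresponding token embedding and projected into the Talker's embedding space (see `speech_token_projector` in `qwen_omni/model.py`). The sketch below illustrates only this coupling; the hidden sizes and tensors are placeholders, not the actual training code.
+
+```python
+# Minimal sketch of the Thinker -> Talker conditioning (placeholder dimensions).
+import torch
+from torch import nn
+
+llm_dim, codec_dim = 1536, 1024        # hypothetical LLM / codec LM hidden sizes
+projector = nn.Linear(2 * llm_dim, codec_dim)
+
+T = 10                                 # number of generated text steps
+thinker_hidden = torch.randn(1, T, llm_dim)   # last-layer hidden states of the Thinker
+thinker_embeds = torch.randn(1, T, llm_dim)   # token embeddings of the same steps
+
+# Concatenate hidden state and embedding per step, then project into the Talker space.
+talker_cond = projector(torch.cat([thinker_hidden, thinker_embeds], dim=-1))
+
+# The projected features are added to the Talker's input embeddings; in the real model
+# the Talker input is additionally offset by a small delay and padded with mask/BOS tokens.
+talker_inputs = torch.zeros(1, T, codec_dim) + talker_cond
+```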
+
+The command for training is:
+```bash
+pip install -r whisper_llm_zh/requirements.txt
+
+pip install "huggingface_hub[cli]"
+mkdir -p models/whisper models/qwen
+
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+# First, we only train the projector and freeze other modules.
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+ --max-duration 200 \
+ --exp-dir ./whisper_llm_zh/exp_test \
+ --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+ --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+ --manifest-dir data/fbank \
+ --deepspeed \
+ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+ --use-flash-attn True \
+ --use-lora False --unfreeze-llm False
+
+# Then we jointly train the projector and LLM LoRA modules.
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+ --max-duration 200 \
+ --exp-dir ./whisper_llm_zh/exp_test \
+ --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+ --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+ --manifest-dir data/fbank \
+ --deepspeed \
+ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+ --use-flash-attn True \
+  --use-lora True --unfreeze-llm True \
+ --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
+```
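+
+The two stages differ only in which parameters receive gradients: stage one trains the encoder projector while the Whisper encoder and the LLM stay frozen, and stage two additionally trains LoRA adapters injected into the LLM. The sketch below shows roughly what the `--unfreeze-llm` and `--use-lora` flags correspond to; the module and argument names are illustrative, not the exact ones used in the training script.
+
+```python
+# Rough sketch of the two-stage freezing scheme (module names are placeholders).
+from peft import LoraConfig, get_peft_model
+
+
+def configure_trainable(model, use_lora: bool, unfreeze_llm: bool):
+    # Stage 1: freeze the speech encoder and the LLM, train only the projector.
+    for p in model.encoder.parameters():
+        p.requires_grad = False
+    for p in model.llm.parameters():
+        p.requires_grad = False
+    for p in model.encoder_projector.parameters():
+        p.requires_grad = True
+
+    # Stage 2: wrap the LLM with LoRA adapters and train them jointly with the projector.
+    if use_lora and unfreeze_llm:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            lora_dropout=0.05,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+        )
+        model.llm = get_peft_model(model.llm, lora_config)
+```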
+
+The command for decoding is:
+```bash
+mkdir -p models/whisper models/qwen models/checkpoint
+huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
+
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
+ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
+
+python3 ./whisper_llm_zh/decode.py \
+ --max-duration 80 \
+ --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
+ --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+ --llm-path-or-name models/qwen \
+ --epoch 999 --avg 1 \
+ --manifest-dir data/fbank \
+ --use-flash-attn True \
+ --use-lora True --dataset aishell
+```
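+
+The generated speech tokens are 50 Hz CosyVoice 1 single-codebook tokens and still need a token-to-wav stage to become audio. `qwen_omni/web_demo.py` does this with the CosyVoice flow and HiFT decoders through its `audio_decode_cosyvoice` helper, which produces 22.05 kHz audio. A minimal sketch of that final step, with the CosyVoice checkpoint path and the token values as placeholders:
+
+```python
+# Sketch: convert generated CosyVoice codec tokens into a waveform.
+import soundfile as sf
+import torch
+from cosyvoice.cli.cosyvoice import CosyVoice
+from web_demo import audio_decode_cosyvoice  # helper defined in qwen_omni/web_demo.py
+
+# Placeholder checkpoint path; point it at a local CosyVoice model directory.
+token2wav = CosyVoice("/path/to/CosyVoice-300M-SFT", load_jit=False, load_trt=False, fp16=False)
+
+# `speech_tokens` is the list returned by SPEECH_LLM.decode_with_speech_output
+# (placeholder values here); tokens >= 4096 are special and are dropped.
+speech_tokens = [1024, 57, 3301]
+audio_tokens = torch.tensor([t for t in speech_tokens if t < 4096], dtype=torch.int32).unsqueeze(0)
+
+wav = audio_decode_cosyvoice(audio_tokens, token2wav)  # (1, num_samples) tensor
+sf.write("output.wav", wav.squeeze(0).cpu().numpy(), samplerate=22050)
+```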
diff --git a/egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg b/egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg
new file mode 100644
index 000000000..d708bb256
Binary files /dev/null and b/egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg differ
diff --git a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
index b01a35c7d..4bc5e5a82 100755
--- a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
+++ b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
@@ -17,6 +17,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Usage:
+ python3 local/compute_whisper_fbank.py \
+ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+ --out-dir data/fbank \
+ --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
+ --audio-key question_audio --text-key answer \
+ --prefix ultrachat
+"""
+
import argparse
import logging
@@ -126,7 +136,7 @@ def compute_fbank(args):
num_digits = 5
for i in range(num_shards):
shard = dataset.shard(num_shards, i)
- shard = shard.take(10) # for testing
+ # shard = shard.take(10) # for testing
logging.info(
f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}"
)
@@ -159,8 +169,6 @@ def compute_fbank(args):
logging.info(f"Saving to {cuts_path}")
# see https://github.com/lhotse-speech/lhotse/issues/1125
cut_set.drop_recordings().to_file(cuts_path)
- if i > 1:
- break
def main():
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
similarity index 100%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
similarity index 100%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json
similarity index 100%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/label_smoothing.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py
similarity index 100%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/label_smoothing.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
similarity index 77%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
index 0cc93c237..97870337d 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
@@ -1,11 +1,14 @@
+from typing import List, Tuple # Added for type hints
+
import torch
from torch import nn
-from transformers.trainer_pt_utils import LabelSmoother
-from typing import List, Tuple # Added for type hints
from torchmetrics.classification import MulticlassAccuracy
+from transformers.trainer_pt_utils import LabelSmoother
+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import logging
+
class EncoderProjector(nn.Module):
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
@@ -69,7 +72,8 @@ class SPEECH_LLM(nn.Module):
self.codec_lm = codec_lm
if self.codec_lm:
self.speech_token_projector = nn.Linear(
- self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
+ self.llm.config.hidden_size + self.llm.config.hidden_size,
+ self.codec_lm.config.hidden_size,
)
self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
@@ -89,6 +93,7 @@ class SPEECH_LLM(nn.Module):
multidim_average="global",
ignore_index=IGNORE_TOKEN_ID,
)
+
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
@@ -274,68 +279,115 @@ class SPEECH_LLM(nn.Module):
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
- input_seq_len = attention_mask.sum(dim=1) # shape, B
- text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
+ input_seq_len = attention_mask.sum(dim=1) # shape, B
+ (
+ text_label_start_index_list,
+ text_input_start_index_list,
+ input_question_len_list,
+ ) = ([], [], [])
for i in range(labels.shape[0]):
input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
input_embeds_start_index = input_embeds_valid_index[0]
text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
text_labels_start_index = text_labels_valid_index[0]
- assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
- assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
+ assert (
+ input_seq_len[i]
+ == input_embeds_valid_index[-1] - input_embeds_start_index + 1
+ ), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
+ assert (
+ input_embeds_valid_index[-1] == text_labels_valid_index[-1]
+ ), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
input_question_len = text_labels_start_index - input_embeds_start_index
- assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
+ assert (
+ input_question_len
+ + text_labels_valid_index[-1]
+ - text_labels_start_index
+ + 1
+ == input_seq_len[i]
+ )
text_label_start_index_list.append(text_labels_start_index)
text_input_start_index_list.append(input_embeds_start_index)
input_question_len_list.append(input_question_len)
model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ labels=labels,
+ output_hidden_states=True,
)
text_loss = model_outputs.loss
delay_step = 1
# prepare codec lm inputs
- audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
+ audio_codes_lens = [
+ len(x) + input_question_len_list[i] + delay_step + 1
+ for i, x in enumerate(speech_codec_ids)
+ ]
max_len_speech_codec = max(audio_codes_lens)
if self.codec_lm_padding_side == "right":
audio_codes = [
- [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.mask_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + [self.codec_lm.config.bos_token_id]
+ + x
+ + [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
- [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
+ [self.codec_lm.config.pad_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + x
+ + [self.codec_lm.config.eos_token_id]
+ + [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
for i, x in enumerate(speech_codec_ids)
]
elif self.codec_lm_padding_side == "left":
audio_codes = [
- [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
+ [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
+ + [self.codec_lm.config.mask_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + [self.codec_lm.config.bos_token_id]
+ + x
for i, x in enumerate(speech_codec_ids)
]
audio_labels = [
- [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
+ [self.codec_lm.config.pad_token_id]
+ * (max_len_speech_codec - audio_codes_lens[i])
+ + [self.codec_lm.config.pad_token_id]
+ * (input_question_len_list[i] + delay_step)
+ + x
+ + [self.codec_lm.config.eos_token_id]
for i, x in enumerate(speech_codec_ids)
]
audio_codes = torch.tensor(
- audio_codes,
- dtype=torch.int64,
- device=input_ids.device
+ audio_codes, dtype=torch.int64, device=input_ids.device
)
audio_labels = torch.tensor(
- audio_labels,
- dtype=torch.int64,
- device=input_ids.device
+ audio_labels, dtype=torch.int64, device=input_ids.device
)
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
-
+
text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
for i in range(len(text_label_start_index_list)):
- text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
+ text_last_hidden = model_outputs.hidden_states[-1][
+ i,
+ text_input_start_index_list[i] : text_input_start_index_list[i]
+ + input_seq_len[i]
+ - 1,
+ ]
text_last_hidden_lists.append(text_last_hidden)
- text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos
+ text_embed = inputs_embeds[
+ i,
+ text_input_start_index_list[i]
+ + 1 : text_input_start_index_list[i]
+ + input_seq_len[i],
+ ] # exclude bos
text_embeds_list.append(text_embed)
text_input_embeds = torch.cat(
@@ -344,22 +396,34 @@ class SPEECH_LLM(nn.Module):
text_embed,
],
dim=-1,
- )# shape, T, D1 + D2
- text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec
+ ) # shape, T, D1 + D2
+ text_input_embeds = self.speech_token_projector(
+ text_input_embeds
+ ) # shape, T, D_codec
text_input_embeds_list.append(text_input_embeds)
-
+
for i in range(audio_embeddings.shape[0]):
text_input_embeds = text_input_embeds_list[i]
if self.codec_lm_padding_side == "right":
- audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
+ audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds
elif self.codec_lm_padding_side == "left":
- start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
+ start_idx = torch.where(
+ audio_codes[i] == self.codec_lm.config.mask_token_id
+ )[0][0]
start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
- assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
+ assert (
+ start_idx == start_idx_re_compute
+ ), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
- text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
- logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
- audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds
+ text_input_embeds = text_input_embeds[
+ : audio_embeddings.shape[1] - start_idx
+ ]
+ logging.warning(
+ f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
+ )
+ audio_embeddings[
+ i, start_idx : start_idx + text_input_embeds.shape[0]
+ ] += text_input_embeds
speech_outputs = self.codec_lm(
attention_mask=audio_attention_mask,
@@ -369,8 +433,10 @@ class SPEECH_LLM(nn.Module):
)
last_hidden_state = speech_outputs.hidden_states[-1].clone()
- audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
- audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size)
+ audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size
+ audio_logits = audio_logits.contiguous().view(
+ -1, self.codec_lm.config.vocab_size
+ )
audio_labels = audio_labels.contiguous().view(-1)
audio_labels = audio_labels.masked_fill(
audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
@@ -378,7 +444,6 @@ class SPEECH_LLM(nn.Module):
codec_loss = self.loss_fct(audio_logits, audio_labels)
audio_preds = torch.argmax(audio_logits, -1)
-
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
@@ -392,12 +457,11 @@ class SPEECH_LLM(nn.Module):
ignore_label=IGNORE_TOKEN_ID,
)
audio_topk_acc = self.audio_accuracy_metric(
- audio_logits.detach(),
- audio_labels.detach()).item()
-
+ audio_logits.detach(), audio_labels.detach()
+ ).item()
return text_loss, acc, codec_loss, audio_acc, audio_topk_acc
-
+
def decode(
self,
fbank: torch.Tensor = None,
@@ -453,12 +517,12 @@ class SPEECH_LLM(nn.Module):
def decode_with_speech_output(
self,
fbank: torch.Tensor = None,
- input_ids: torch.LongTensor = None, # Prompt input_ids
- attention_mask: torch.Tensor = None, # Prompt attention_mask
+ input_ids: torch.LongTensor = None, # Prompt input_ids
+ attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024,
- max_speech_new_tokens: int = 1024, # Max length for speech tokens
- llm_kwargs: dict = None, # Kwargs for text LLM generate
- codec_lm_kwargs: dict = None # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
+ max_speech_new_tokens: int = 1024, # Max length for speech tokens
+ llm_kwargs: dict = None, # Kwargs for text LLM generate
+ codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]:
"""
Generates text and corresponding speech tokens using the revised logic.
@@ -479,16 +543,22 @@ class SPEECH_LLM(nn.Module):
the generated speech codec tokens for a batch item.
"""
assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
- if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
- raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")
+ if (
+ not self.codec_lm
+ or not self.speech_token_projector
+ or not self.codec_lm_head
+ ):
+ raise ValueError(
+ "codec_lm and associated layers must be initialized to generate speech output."
+ )
- device = next(self.parameters()).device # Use model's device
+ device = next(self.parameters()).device # Use model's device
batch_size = fbank.shape[0]
# --- 1. Prepare Prompt Embeddings ---
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
- speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
+ speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype
prompt_embeds = self.llm.get_input_embeddings()(input_ids)
@@ -511,12 +581,12 @@ class SPEECH_LLM(nn.Module):
"eos_token_id": self.llm.config.eos_token_id,
"pad_token_id": self.llm.config.pad_token_id,
"num_beams": 1,
- "do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
+ "do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed
"top_p": 0.5,
"top_k": 20,
"repetition_penalty": 1.1,
"temperature": 0.7,
- **(llm_kwargs or {}) # User-provided kwargs override defaults
+ **(llm_kwargs or {}), # User-provided kwargs override defaults
}
text_outputs = self.llm.generate(
@@ -525,17 +595,22 @@ class SPEECH_LLM(nn.Module):
max_new_tokens=max_text_new_tokens,
return_dict_in_generate=True,
output_hidden_states=True,
- **final_llm_kwargs
+ **final_llm_kwargs,
)
delay_step = 1
- generated_text_ids = text_outputs.sequences # [B, S_full]
+ generated_text_ids = text_outputs.sequences # [B, S_full]
eos_token_id = self.llm.config.eos_token_id
- eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D
- assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
+ eos_token_embedding = self.llm.get_input_embeddings()(
+ torch.tensor([[eos_token_id]], device=device)
+ ) # 1,D
+ assert (
+ generated_text_ids[0, -1] == eos_token_id
+ ), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
thinker_token_embeds_org = [
- token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+ token_hidden_states[0].to(self.llm.device)
+ for token_hidden_states in text_outputs.hidden_states
]
- # shift one for thinker token_embeds, drop the first embeds, and add the eos token
+ # shift one for thinker token_embeds, drop the first embeds, and add the eos token
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
@@ -544,19 +619,27 @@ class SPEECH_LLM(nn.Module):
dim=1,
)
- thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
+ thinker_token_embeds = (
+ [first_thinker_token_embed]
+ + thinker_token_embeds_org[2:]
+ + [eos_token_embedding]
+ )
thinker_hidden_states = [
- token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+ token_hidden_states[-1].to(self.llm.device)
+ for token_hidden_states in text_outputs.hidden_states
]
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
- thinker_reply_part = [torch.cat(
- [
- thinker_hidden_state,
- thinker_token_embed,
- ],
- dim=-1,
- )
- for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
+ thinker_reply_part = [
+ torch.cat(
+ [
+ thinker_hidden_state,
+ thinker_token_embed,
+ ],
+ dim=-1,
+ )
+ for thinker_hidden_state, thinker_token_embed in zip(
+ thinker_hidden_states[1:], thinker_token_embeds[1:]
+ )
]
thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
# thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
@@ -568,26 +651,35 @@ class SPEECH_LLM(nn.Module):
dim=-1,
)
- thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec]
- thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec]
-
+ thinker_prompt_part = self.speech_token_projector(
+ thinker_prompt_part
+ ) # [B, S_full, D_codec]
+ thinker_reply_part = self.speech_token_projector(
+ thinker_reply_part
+ ) # [B, S_full, D_codec]
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
- (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
+ (batch_size, thinker_prompt_part_seq_len + delay_step + 1),
+ self.codec_lm.config.mask_token_id,
+ dtype=torch.long,
+ device=self.llm.device,
)
- talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id
- talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec]
+ talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
+ talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
+ talker_input_ids
+ ) # [B, S_full, D_codec]
thinker_input_embeds = torch.cat(
[
thinker_prompt_part,
- thinker_reply_part[:, :delay_step + 1, :],
+ thinker_reply_part[:, : delay_step + 1, :],
],
dim=1,
)
talker_inputs_embeds += thinker_input_embeds
- thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec]
-
+ thinker_reply_part = thinker_reply_part[
+ :, delay_step + 1 :, :
+ ] # [B, S_full, D_codec]
past_key_values = None
# generated_speech_tokens_list = [[] for _ in range(batch_size)]
@@ -599,10 +691,14 @@ class SPEECH_LLM(nn.Module):
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
if t > 0:
- talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec]
+ talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
+ next_token_ids
+ ) # [B, 1, D_codec]
if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :]
- thinker_reply_part = thinker_reply_part[:, 1:, :] # Remove the first token for next step
+ thinker_reply_part = thinker_reply_part[
+ :, 1:, :
+ ] # Remove the first token for next step
# # Add the projected text embedding corresponding to the current timestep `t`
# if t < text_context_len:
# # Text context from the full generated text sequence
@@ -611,20 +707,24 @@ class SPEECH_LLM(nn.Module):
# else:
# # No more text context to add
# inputs_embeds = current_speech_embeds
-
+
# Forward pass through codec LM for one step
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
codec_outputs = self.codec_lm(
- inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
+ inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
output_hidden_states=True,
# No attention mask needed here when using past_key_values and single token input
)
- last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] #TODO: check shape here
+ last_token_hidden_state = codec_outputs.hidden_states[-1][
+ :, -1, :
+ ] # [B, D_codec] #TODO: check shape here
# Get logits for the *last* token generated in this step
- next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index
+ next_token_logits = self.codec_lm_head(
+ last_token_hidden_state
+ ) # Use -1 index
# suppress tokens between 4096:len(vocab)-3
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
next_token_ids = topk_sampling(
@@ -634,11 +734,14 @@ class SPEECH_LLM(nn.Module):
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
- past_key_values = codec_outputs.past_key_values # Update KV cache
- generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])
+ past_key_values = codec_outputs.past_key_values # Update KV cache
+ generated_speech_tokens_list.append(
+ next_token_ids.squeeze(1).cpu().tolist()[0]
+ )
# --- 6. Return Results ---
return generated_text_ids, generated_speech_tokens_list
+
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
@@ -717,4 +820,4 @@ def top_k_top_p_filtering(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
- return logits
\ No newline at end of file
+ return logits
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/speech_dataset.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/speech_dataset.py
similarity index 99%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/speech_dataset.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/speech_dataset.py
index d0a77fd0e..43a4efb5a 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/speech_dataset.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/speech_dataset.py
@@ -1,13 +1,12 @@
from typing import Callable, Dict, List, Union
import torch
-from torch.utils.data.dataloader import DataLoader, default_collate
-
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix
+from torch.utils.data.dataloader import DataLoader, default_collate
class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
similarity index 100%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
similarity index 60%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
index 3155174fb..e33d2437d 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py
@@ -1,26 +1,25 @@
# Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
import io
-
-import numpy as np
-import gradio as gr
-import soundfile as sf
-
-import gradio.processing_utils as processing_utils
-import tempfile
-from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
-from gradio_client import utils as client_utils
-
-from argparse import ArgumentParser
-import whisper
-import torch
-from peft import LoraConfig, get_peft_model
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from model import SPEECH_LLM, EncoderProjector
-from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
-import sherpa_onnx
-from cosyvoice.cli.cosyvoice import CosyVoice
import sys
-sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS')
+from argparse import ArgumentParser
+
+import gradio as gr
+import gradio.processing_utils as processing_utils
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+import torch
+import whisper
+from cosyvoice.cli.cosyvoice import CosyVoice
+from gradio_client import utils as client_utils
+from model import SPEECH_LLM, EncoderProjector
+from peft import LoraConfig, get_peft_model
+from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def get_model(params, device="cuda"):
@@ -88,7 +87,7 @@ def get_model(params, device="cuda"):
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
- torch_dtype=torch.float16
+ torch_dtype=torch.float16,
)
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
@@ -102,12 +101,10 @@ def get_model(params, device="cuda"):
llm,
encoder_projector,
codec_lm,
- codec_lm_padding_side= "left" if params.use_flash_attn else "right",
+ codec_lm_padding_side="left" if params.use_flash_attn else "right",
)
- checkpoint = torch.load(
- f"{params.checkpoint_path}", map_location="cpu"
- )
+ checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
model.to(device)
@@ -122,27 +119,37 @@ def audio_decode_cosyvoice(audio_tokens, codec_decoder):
Args:
audio_tokens (list): List of audio tokens to be processed.
codec_decoder: Codec decoder for generating audio.
-
+
Returns:
torch.Tensor: Generated audio waveform.
"""
- flow_embedding = codec_decoder.frontend.spk2info['中文女']['embedding']
+ flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"]
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
prompt_speech_feat = torch.zeros(1, 0, 80)
- tts_mel, _ = codec_decoder.model.flow.inference(token=audio_tokens.to(codec_decoder.model.device),
- token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
- prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
- prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
- prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
- prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(codec_decoder.model.device),
- embedding=flow_embedding.to(codec_decoder.model.device),
- flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),)
+ tts_mel, _ = codec_decoder.model.flow.inference(
+ token=audio_tokens.to(codec_decoder.model.device),
+ token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+ codec_decoder.model.device
+ ),
+ prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device),
+ prompt_token_len=torch.tensor(
+ [flow_prompt_speech_token.shape[1]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ prompt_feat=prompt_speech_feat.to(codec_decoder.model.device),
+ prompt_feat_len=torch.tensor(
+ [prompt_speech_feat.shape[1]], dtype=torch.int32
+ ).to(codec_decoder.model.device),
+ embedding=flow_embedding.to(codec_decoder.model.device),
+ flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),
+ )
-
- audio_hat, _ = codec_decoder.model.hift.inference(speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0))
+ audio_hat, _ = codec_decoder.model.hift.inference(
+ speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+ )
return audio_hat
+
def preprocess(
messages,
tokenizer,
@@ -178,28 +185,14 @@ def preprocess(
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
-
-
-def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
+
+def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def format_history(history: list):
messages = []
for item in history:
if isinstance(item["content"], str):
- messages.append({"role": item['role'], "content": item['content']})
- # elif item["role"] == "user" and (isinstance(item["content"], list) or
- # isinstance(item["content"], tuple)):
- # file_path = item["content"][0]
- # # TODO: check if the file_path's transcript is already in the history
- # mime_type = client_utils.get_mimetype(file_path)
- # if mime_type.startswith("audio"):
- # messages.append({
- # "role":
- # item['role'],
- # "content": item["content"][1] # append audio transcript here
- # })
- print('predict history: ', messages)
- # messages = messages[-2:] # TODO: WAR: add history later
+ messages.append({"role": item["role"], "content": item["content"]})
return messages
def decode(
@@ -217,9 +210,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
dtype = torch.float32
device = model.llm.device
- feature = feature.to(device, dtype=dtype)#.transpose(1, 2)
- # assert feature.shape[2] == 80
-
+ feature = feature.to(device, dtype=dtype)
+
input_ids, attention_mask = preprocess([messages], tokenizer)
generated_ids, audio_tokens = model.decode_with_speech_output(
@@ -227,26 +219,21 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
- # print('hyps: ', hyps, 23333333333333333333333333)
+
yield {"type": "text", "data": hyps[0]}
- # yield {"type": "text", "data": hyps}
audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
- audio = audio_hat.squeeze(0).cpu().numpy()
- # sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050)
+ audio = audio_hat.squeeze(0).cpu().numpy()
audio = np.array(audio * 32767).astype(np.int16)
- # yield {"type": "audio", "data": (22050, audio)}
- # with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
- # sf.write(tmpfile.name, audio, 22050, format="WAV")
- # audio_path = tmpfile.name
wav_io = io.BytesIO()
sf.write(wav_io, audio, samplerate=22050, format="WAV")
wav_io.seek(0)
wav_bytes = wav_io.getvalue()
audio_path = processing_utils.save_bytes_to_cache(
- wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
+ wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
+ )
yield {"type": "audio", "data": audio_path}
@@ -259,25 +246,27 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
gr.update(visible=True), # stop_btn
)
print(2333, history, audio)
- history.append({"role": "user", "content": (audio,)})
- history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
+ history.append({"role": "user", "content": (audio,)})
+ history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
history.append({"role": "assistant", "content": ""})
- formatted_history = format_history(history=history) # only keep string text format
+ formatted_history = format_history(
+ history=history
+ ) # only keep string text format
assert audio is not None
audio_transcript = get_transcript(
audio,
asr_model,
)
- print('audio_transcript: ', audio_transcript)
history[-2]["content"] = audio_transcript
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
fbank = fbank.unsqueeze(0)
assert fbank.ndim == 3
- # history.append({"role": "assistant", "content": ""})
- for chunk in decode(model, token2wav_model, tokenizer, fbank, formatted_history):
+ for chunk in decode(
+ model, token2wav_model, tokenizer, fbank, formatted_history
+ ):
if chunk["type"] == "text":
history[-1]["content"] = chunk["data"]
yield (
@@ -287,10 +276,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
gr.update(visible=True), # stop_btn
)
if chunk["type"] == "audio":
- history.append({
- "role": "assistant",
- "content": gr.Audio(chunk["data"])
- })
+ history.append(
+ {"role": "assistant", "content": gr.Audio(chunk["data"])}
+ )
# Final yield
yield (
@@ -304,8 +292,7 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
with gr.Tab("Online"):
with gr.Row():
with gr.Column(scale=1):
- microphone = gr.Audio(sources=['microphone'],
- type="filepath")
+ microphone = gr.Audio(sources=["microphone"], type="filepath")
submit_btn = gr.Button("Submit", variant="primary")
stop_btn = gr.Button("Stop", visible=False)
clear_btn = gr.Button("Clear History")
@@ -315,64 +302,80 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def clear_history():
return [], gr.update(value=None)
- submit_event = submit_btn.click(fn=media_predict,
- inputs=[
- microphone,
- media_chatbot,
- ],
- outputs=[
- microphone,
- media_chatbot, submit_btn,
- stop_btn
- ])
+ submit_event = submit_btn.click(
+ fn=media_predict,
+ inputs=[
+ microphone,
+ media_chatbot,
+ ],
+ outputs=[microphone, media_chatbot, submit_btn, stop_btn],
+ )
stop_btn.click(
- fn=lambda:
- (gr.update(visible=True), gr.update(visible=False)),
+ fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
inputs=None,
outputs=[submit_btn, stop_btn],
cancels=[submit_event],
- queue=False)
- clear_btn.click(fn=clear_history,
- inputs=None,
- outputs=[media_chatbot, microphone])
+ queue=False,
+ )
+ clear_btn.click(
+ fn=clear_history, inputs=None, outputs=[media_chatbot, microphone]
+ )
- demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
- ssr_mode=False,
- share=args.share,
- inbrowser=args.inbrowser,
- server_port=args.server_port,
- server_name=args.server_name,)
+ demo.queue(default_concurrency_limit=100, max_size=100).launch(
+ max_threads=100,
+ ssr_mode=False,
+ share=args.share,
+ inbrowser=args.inbrowser,
+ server_port=args.server_port,
+ server_name=args.server_name,
+ )
def _get_args():
parser = ArgumentParser()
- parser.add_argument('--checkpoint-path',
- type=str,
- default=None,
- help='Checkpoint name or path, default to %(default)r')
- parser.add_argument('--token2wav-path',
- type=str,
- default=None,
- help='Token2Wav path, default to %(default)r')
- parser.add_argument('--asr-model-dir',
- type=str,
- default=None,
- help='ASR model dir, default to %(default)r')
- parser.add_argument('--flash-attn2',
- action='store_true',
- default=False,
- help='Enable flash_attention_2 when loading the model.')
- parser.add_argument('--share',
- action='store_true',
- default=False,
- help='Create a publicly shareable link for the interface.')
- parser.add_argument('--inbrowser',
- action='store_true',
- default=False,
- help='Automatically launch the interface in a new tab on the default browser.')
- parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.')
- parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
+ parser.add_argument(
+ "--checkpoint-path",
+ type=str,
+ default=None,
+ help="Checkpoint name or path, default to %(default)r",
+ )
+ parser.add_argument(
+ "--token2wav-path",
+ type=str,
+ default=None,
+ help="Token2Wav path, default to %(default)r",
+ )
+ parser.add_argument(
+ "--asr-model-dir",
+ type=str,
+ default=None,
+ help="ASR model dir, default to %(default)r",
+ )
+ parser.add_argument(
+ "--flash-attn2",
+ action="store_true",
+ default=False,
+ help="Enable flash_attention_2 when loading the model.",
+ )
+ parser.add_argument(
+ "--share",
+ action="store_true",
+ default=False,
+ help="Create a publicly shareable link for the interface.",
+ )
+ parser.add_argument(
+ "--inbrowser",
+ action="store_true",
+ default=False,
+ help="Automatically launch the interface in a new tab on the default browser.",
+ )
+ parser.add_argument(
+ "--server-port", type=int, default=8001, help="Demo server port."
+ )
+ parser.add_argument(
+ "--server-name", type=str, default="127.0.0.1", help="Demo server name."
+ )
add_model_arguments(parser)
args = parser.parse_args()
return args
@@ -401,6 +404,7 @@ def read_wave(wave_filename: str):
return samples_float32, sample_rate
+
def get_transcript(audio_path, recognizer):
samples, sample_rate = read_wave(audio_path)
s = recognizer.create_stream()
@@ -408,10 +412,13 @@ def get_transcript(audio_path, recognizer):
recognizer.decode_streams([s])
return s.result.text
+
if __name__ == "__main__":
args = _get_args()
model, tokenizer = get_model(args)
- token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
+ token2wav = CosyVoice(
+ args.token2wav_path, load_jit=False, load_trt=False, fp16=False
+ )
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
@@ -423,4 +430,4 @@ if __name__ == "__main__":
debug=False,
)
- _launch_demo(args, model, tokenizer, token2wav, asr_model)
\ No newline at end of file
+ _launch_demo(args, model, tokenizer, token2wav, asr_model)
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/whisper_encoder_forward_monkey_patch.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/whisper_encoder_forward_monkey_patch.py
similarity index 100%
rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/whisper_encoder_forward_monkey_patch.py
rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/whisper_encoder_forward_monkey_patch.py