remove unused

Yuekai Zhang 2025-04-25 14:21:50 +08:00
parent 6ea7ec8543
commit 9a07363a8d


@@ -1,26 +1,18 @@
 # Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py
 import io
-import os
-import ffmpeg
 import numpy as np
 import gradio as gr
 import soundfile as sf
-#import modelscope_studio.components.base as ms
-#import modelscope_studio.components.antd as antd
 import gradio.processing_utils as processing_utils
-#from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
+from transformers import AutoModelForCausalLM
 from gradio_client import utils as client_utils
-#from qwen_omni_utils import process_mm_info
 from argparse import ArgumentParser


 def _load_model_processor(args):
-    if args.cpu_only:
-        device_map = 'cpu'
-    else:
-        device_map = 'auto'
-
     # Check if flash-attn2 flag is enabled and load model accordingly
     if args.flash_attn2:
@@ -35,37 +27,9 @@ def _load_model_processor(args):
     return model, processor


 def _launch_demo(args, model, processor):
-    # Voice settings
-    VOICE_LIST = ['Chelsie', 'Ethan']
-    DEFAULT_VOICE = 'Chelsie'
-    default_system_prompt = 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'
-    language = args.ui_language
-    # def get_text(text: str, cn_text: str):
-    #     if language == 'en':
-    #         return text
-    #     if language == 'zh':
-    #         return cn_text
-    #     return text
-    # def convert_webm_to_mp4(input_file, output_file):
-    #     try:
-    #         (
-    #             ffmpeg
-    #             .input(input_file)
-    #             .output(output_file, acodec='aac', ar='16000', audio_bitrate='192k')
-    #             .run(quiet=True, overwrite_output=True)
-    #         )
-    #         print(f"Conversion successful: {output_file}")
-    #     except ffmpeg.Error as e:
-    #         print("An error occurred during conversion.")
-    #         print(e.stderr.decode('utf-8'))
-    def format_history(history: list, system_prompt: str):
+    def format_history(history: list):
         messages = []
-        # messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
         for item in history:
             if isinstance(item["content"], str):
                 messages.append({"role": item['role'], "content": item['content']})
@@ -74,25 +38,7 @@ def _launch_demo(args, model, processor):
                 file_path = item["content"][0]
                 mime_type = client_utils.get_mimetype(file_path)
-                if mime_type.startswith("image"):
-                    messages.append({
-                        "role":
-                        item['role'],
-                        "content": [{
-                            "type": "image",
-                            "image": file_path
-                        }]
-                    })
-                elif mime_type.startswith("video"):
-                    messages.append({
-                        "role":
-                        item['role'],
-                        "content": [{
-                            "type": "video",
-                            "video": file_path
-                        }]
-                    })
-                elif mime_type.startswith("audio"):
+                if mime_type.startswith("audio"):
                     messages.append({
                         "role":
                         item['role'],
@@ -103,17 +49,17 @@ def _launch_demo(args, model, processor):
                     })
         return messages

-    def predict(messages, voice=DEFAULT_VOICE):
+    def predict(messages):
         print('predict history: ', messages)
         text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-        audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
-        inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=True)
+        audios = [msg['content'][0]['audio'] for msg in messages if msg['role'] == 'user' and isinstance(msg['content'], list) and msg['content'][0]['type'] == 'audio']
+        inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
         inputs = inputs.to(model.device).to(model.dtype)
-        text_ids, audio = model.generate(**inputs, speaker=voice, use_audio_in_video=True)
+        text_ids, audio = model.generate(**inputs)
         response = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         response = response[0].split("\n")[-1]
@@ -128,37 +74,31 @@ def _launch_demo(args, model, processor):
                 wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
             yield {"type": "audio", "data": audio_path}

-    def media_predict(audio, video, history, system_prompt, voice_choice):
+    def media_predict(audio, history):
         # First yield
         yield (
             None,  # microphone
-            None,  # webcam
             history,  # media_chatbot
             gr.update(visible=False),  # submit_btn
             gr.update(visible=True),  # stop_btn
         )

-        if video is not None:
-            convert_webm_to_mp4(video, video.replace('.webm', '.mp4'))
-            video = video.replace(".webm", ".mp4")
-        files = [audio, video]
+        files = [audio]
         for f in files:
             if f:
                 history.append({"role": "user", "content": (f, )})

-        formatted_history = format_history(history=history,
-                                           system_prompt=system_prompt,)
+        formatted_history = format_history(history=history)

         history.append({"role": "assistant", "content": ""})

-        for chunk in predict(formatted_history, voice_choice):
+        for chunk in predict(formatted_history):
             if chunk["type"] == "text":
                 history[-1]["content"] = chunk["data"]
                 yield (
                     None,  # microphone
-                    None,  # webcam
                     history,  # media_chatbot
                     gr.update(visible=False),  # submit_btn
                     gr.update(visible=True),  # stop_btn
@@ -172,79 +112,47 @@ def _launch_demo(args, model, processor):
         # Final yield
         yield (
             None,  # microphone
-            None,  # webcam
             history,  # media_chatbot
             gr.update(visible=True),  # submit_btn
             gr.update(visible=False),  # stop_btn
         )

-    with gr.Blocks() as demo, ms.Application(), antd.ConfigProvider():
-        with gr.Sidebar(open=False):
-            system_prompt_textbox = gr.Textbox(label="System Prompt",
-                                               value=default_system_prompt)
-        with antd.Flex(gap="small", justify="center", align="center"):
-            with antd.Flex(vertical=True, gap="small", align="center"):
-                antd.Typography.Title("Qwen2.5-Omni Demo",
-                                      level=1,
-                                      elem_style=dict(margin=0, fontSize=28))
-            with antd.Flex(vertical=True, gap="small"):
-                antd.Typography.Text(get_text("🎯 Instructions for use:",
-                                              "🎯 使用说明:"),
-                                     strong=True)
-                antd.Typography.Text(
-                    get_text(
-                        "1⃣ Click the Audio Record button or the Camera Record button.",
-                        "1⃣ 点击音频录制按钮,或摄像头-录制按钮"))
-                antd.Typography.Text(
-                    get_text("2⃣ Input audio or video.", "2⃣ 输入音频或者视频"))
-                antd.Typography.Text(
-                    get_text(
-                        "3⃣ Click the submit button and wait for the model's response.",
-                        "3⃣ 点击提交并等待模型的回答"))
-        voice_choice = gr.Dropdown(label="Voice Choice",
-                                   choices=VOICE_LIST,
-                                   value=DEFAULT_VOICE)
-        with gr.Tabs():
-            with gr.Tab("Online"):
-                with gr.Row():
-                    with gr.Column(scale=1):
-                        microphone = gr.Audio(sources=['microphone'],
-                                              type="filepath")
-                        webcam = gr.Video(sources=['webcam'],
-                                          height=400,
-                                          include_audio=True)
-                        submit_btn = gr.Button(get_text("Submit", "提交"),
-                                               variant="primary")
-                        stop_btn = gr.Button(get_text("Stop", "停止"), visible=False)
-                        clear_btn = gr.Button(get_text("Clear History", "清除历史"))
-                    with gr.Column(scale=2):
-                        media_chatbot = gr.Chatbot(height=650, type="messages")
+    with gr.Blocks() as demo:
+        with gr.Tab("Online"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    microphone = gr.Audio(sources=['microphone'],
+                                          type="filepath")
+                    submit_btn = gr.Button(get_text("Submit", "提交"),
+                                           variant="primary")
+                    stop_btn = gr.Button(get_text("Stop", "停止"), visible=False)
+                    clear_btn = gr.Button(get_text("Clear History", "清除历史"))
+                with gr.Column(scale=2):
+                    media_chatbot = gr.Chatbot(height=650, type="messages")

                 def clear_history():
-                    return [], gr.update(value=None), gr.update(value=None)
+                    return [], gr.update(value=None)

                 submit_event = submit_btn.click(fn=media_predict,
                                                 inputs=[
-                                                    microphone, webcam,
+                                                    microphone,
                                                     media_chatbot,
-                                                    system_prompt_textbox,
-                                                    voice_choice
                                                 ],
                                                 outputs=[
-                                                    microphone, webcam,
+                                                    microphone,
                                                     media_chatbot, submit_btn,
                                                     stop_btn
                                                 ])
                 stop_btn.click(
                     fn=lambda:
                     (gr.update(visible=True), gr.update(visible=False)),
                     inputs=None,
                     outputs=[submit_btn, stop_btn],
                     cancels=[submit_event],
                     queue=False)
                 clear_btn.click(fn=clear_history,
                                 inputs=None,
-                                outputs=[media_chatbot, microphone, webcam])
+                                outputs=[media_chatbot, microphone])

     demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
                                                                    ssr_mode=False,
@@ -254,16 +162,13 @@ def _launch_demo(args, model, processor):
                                                                    server_name=args.server_name,)

-DEFAULT_CKPT_PATH = "Qwen/Qwen2.5-Omni-7B"

 def _get_args():
     parser = ArgumentParser()
-    parser.add_argument('-c',
-                        '--checkpoint-path',
+    parser.add_argument('--checkpoint-path',
                         type=str,
-                        default=DEFAULT_CKPT_PATH,
+                        default=None,
                         help='Checkpoint name or path, default to %(default)r')
-    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
     parser.add_argument('--flash-attn2',
                         action='store_true',
@@ -279,7 +184,6 @@ def _get_args():
                         help='Automatically launch the interface in a new tab on the default browser.')
     parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
     parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
-    parser.add_argument('--ui-language', type=str, choices=['en', 'zh'], default='en', help='Display language for the UI.')
     args = parser.parse_args()
     return args
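
For reference, a minimal sketch (not part of the commit) of the audio-only message flow after this change; the file name "example.wav" and the history values below are made-up placeholders:

# Hedged sketch only: "example.wav" is a placeholder, not a file from the repository.
history = [{"role": "user", "content": ("example.wav",)}]   # as media_predict() appends it

# format_history() turns the tuple entry into a typed audio message:
messages = [{"role": "user",
             "content": [{"type": "audio", "audio": "example.wav"}]}]

# predict() then collects the audio paths with the new list comprehension:
audios = [msg['content'][0]['audio'] for msg in messages
          if msg['role'] == 'user'
          and isinstance(msg['content'], list)
          and msg['content'][0]['type'] == 'audio']
assert audios == ["example.wav"]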