add history cache

root 2025-04-25 10:05:07 +00:00
parent 47920c2336
commit 71a0a442a6
3 changed files with 116 additions and 31 deletions


@@ -107,4 +107,30 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "stage 6: "
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
exp_dir=./slam_omni/exp_speech2speech_rerun
python3 ./slam_omni/web_demo.py \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-998.pt \
--use-flash-attn True \
--enable-speech-output True \
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "stage 7: "
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
pip install sherpa-onnx
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
fi
fi
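For reference, the model downloaded in stage 7 can be sanity-checked with a minimal sketch like the one below, mirroring the recognizer settings used in web_demo.py further down; the test wav path is a hypothetical placeholder.
# Minimal sanity-check sketch for the stage 7 download (test wav path is hypothetical).
import sherpa_onnx
import soundfile as sf

model_dir = "local/sherpa-onnx-paraformer-zh-2023-09-14"
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
    paraformer=f"{model_dir}/model.int8.onnx",
    tokens=f"{model_dir}/tokens.txt",
    num_threads=2,
    sample_rate=16000,
    feature_dim=80,
    decoding_method="greedy_search",
)
samples, sample_rate = sf.read("local/test_zh.wav", dtype="float32")  # hypothetical test file
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_streams([stream])
print(stream.result.text)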


@@ -630,7 +630,7 @@ class SPEECH_LLM(nn.Module):
next_token_ids = topk_sampling(
next_token_logits,
)
print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
# print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
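The topk_sampling helper called above is not shown in this diff; a minimal sketch of a typical top-k sampling step, under that assumption, could look like this:
import torch

def topk_sampling_sketch(logits, top_k=50, temperature=1.0):
    # Illustrative only; the actual topk_sampling in model.py may differ.
    # logits: (batch, vocab). Keep the top_k logits, renormalize, sample one id.
    logits = logits / temperature
    values, indices = torch.topk(logits, top_k, dim=-1)
    probs = torch.softmax(values, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)  # (batch, 1) position within top_k
    return indices.gather(-1, sampled)                 # (batch, 1) vocabulary ids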


@@ -6,7 +6,7 @@ import gradio as gr
import soundfile as sf
import gradio.processing_utils as processing_utils
import tempfile
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from gradio_client import utils as client_utils
@@ -17,8 +17,11 @@ from peft import LoraConfig, get_peft_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from model import SPEECH_LLM, EncoderProjector
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
import sherpa_onnx
from cosyvoice.cli.cosyvoice import CosyVoice
import sys
sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS')
def get_model(params, device="cuda"):
"""Load and prepare the speech-to-speech model."""
@@ -177,26 +180,26 @@ def preprocess(
return input_ids, attention_mask
def _launch_demo(args, model, tokenizer, token2wav_model):
def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
def format_history(history: list):
messages = []
for item in history:
if isinstance(item["content"], str):
messages.append({"role": item['role'], "content": item['content']})
elif item["role"] == "user" and (isinstance(item["content"], list) or
isinstance(item["content"], tuple)):
file_path = item["content"][0]
# TODO: check if the file_path's transcript is already in the history
mime_type = client_utils.get_mimetype(file_path)
if mime_type.startswith("audio"):
messages.append({
"role":
item['role'],
"content": item["content"][1] # append audio transcript here
})
# elif item["role"] == "user" and (isinstance(item["content"], list) or
# isinstance(item["content"], tuple)):
# file_path = item["content"][0]
# # TODO: check if the file_path's transcript is already in the history
# mime_type = client_utils.get_mimetype(file_path)
# if mime_type.startswith("audio"):
# messages.append({
# "role":
# item['role'],
# "content": item["content"][1] # append audio transcript here
# })
print('predict history: ', messages)
messages = messages[-2:] # TODO: WAR: add history later
# messages = messages[-2:] # TODO: WAR: add history later
return messages
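Since the point of this commit is to keep the full text history rather than only the last two messages, here is a hypothetical example of what format_history now returns for a two-turn conversation:
# Hypothetical two-turn gradio history; audio turns have already been replaced
# by their ASR transcripts (see history[-2]["content"] = audio_transcript below).
history = [
    {"role": "user", "content": "transcript of the first question"},
    {"role": "assistant", "content": "first reply"},
    {"role": "user", "content": "transcript of the second question"},
    {"role": "assistant", "content": ""},
]
# With the messages[-2:] truncation commented out, format_history(history)
# returns all four messages instead of only the last user/assistant pair.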
def decode(
@@ -214,8 +217,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
dtype = torch.float32
device = model.llm.device
feature = feature.to(device, dtype=dtype).transpose(1, 2)
assert feature.shape[2] == 80
feature = feature.to(device, dtype=dtype)#.transpose(1, 2)
# assert feature.shape[2] == 80
input_ids, attention_mask = preprocess([messages], tokenizer)
@@ -224,7 +227,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
yield {"type": "text", "data": hyps}
# print('hyps: ', hyps, 23333333333333333333333333)
yield {"type": "text", "data": hyps[0]}
# yield {"type": "text", "data": hyps}
audio_tokens = [token for token in audio_tokens if token < 4096]
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
@@ -232,6 +237,10 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
audio = audio_hat.squeeze(0).cpu().numpy()
# sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050)
audio = np.array(audio * 32767).astype(np.int16)
# yield {"type": "audio", "data": (22050, audio)}
# with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
# sf.write(tmpfile.name, audio, 22050, format="WAV")
# audio_path = tmpfile.name
wav_io = io.BytesIO()
sf.write(wav_io, audio, samplerate=22050, format="WAV")
wav_io.seek(0)
@@ -249,18 +258,22 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
gr.update(visible=False), # submit_btn
gr.update(visible=True), # stop_btn
)
assert audio is not None
# get audio transcript here
print(2333, history, audio)
history.append({"role": "user", "content": (audio,)})
history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
history.append({"role": "assistant", "content": ""})
formatted_history = format_history(history=history)
formatted_history = format_history(history=history) # only keep string text format
# audio_transcript = get_audio_transcript(audio)
audio_transcript = "audio transcript"
history[-2]["content"] = (audio, audio_transcript)
fbank = whisper.log_mel_spectrogram(audio, model.llm.device)
print('fbank: ', fbank.shape)
assert audio is not None
audio_transcript = get_transcript(
audio,
asr_model,
)
print('audio_transcript: ', audio_transcript)
history[-2]["content"] = audio_transcript
fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
fbank = fbank.unsqueeze(0)
assert fbank.ndim == 3
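# Shape note (assumes whisper's default 80 mel bins for this large-v2 based model):
#   whisper.log_mel_spectrogram(audio, device=...)  -> (80, n_frames)
#   fbank.unsqueeze(0)                              -> (1, 80, n_frames), hence ndim == 3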
# history.append({"role": "assistant", "content": ""})
@@ -342,6 +355,10 @@ def _get_args():
type=str,
default=None,
help='Token2Wav path, default to %(default)r')
parser.add_argument('--asr-model-dir',
type=str,
default=None,
help='ASR model dir, default to %(default)r')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
@@ -354,14 +371,56 @@ def _get_args():
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.')
parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
add_model_arguments(parser)
args = parser.parse_args()
return args
def read_wave(wave_filename: str):
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and can be of type
32-bit floating point PCM. Its sample rate does not need to be 16 kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples,
which are normalized to the range [-1, 1].
- Sample rate of the wave file.
"""
samples, sample_rate = sf.read(wave_filename, dtype="float32")
assert (
samples.ndim == 1
), f"Expected single channel, but got {samples.ndim} channels."
samples_float32 = samples.astype(np.float32)
return samples_float32, sample_rate
def get_transcript(audio_path, recognizer):
samples, sample_rate = read_wave(audio_path)
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
recognizer.decode_streams([s])
return s.result.text
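A hedged usage sketch for the two helpers above; the wav path is hypothetical and asr_model is the recognizer constructed in __main__ below.
# Usage sketch (hypothetical wav path; asr_model is built in __main__ below):
#   text = get_transcript("local/test_zh.wav", asr_model)
#   print(text)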
if __name__ == "__main__":
args = _get_args()
model, tokenizer = get_model(args)
cosyvoice = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
_launch_demo(args, model, tokenizer, cosyvoice)
token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=f"{args.asr_model_dir}/model.int8.onnx",
tokens=f"{args.asr_model_dir}/tokens.txt",
num_threads=2,
sample_rate=16000,
feature_dim=80,
decoding_method="greedy_search",
debug=False,
)
_launch_demo(args, model, tokenizer, token2wav, asr_model)