mirror of https://github.com/k2-fsa/icefall.git
add history cache

This commit is contained in:
  parent 47920c2336
  commit 71a0a442a6
@@ -107,4 +107,30 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
     # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
 fi
 
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "stage 6: launch the speech-to-speech web demo"
+  export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
+  exp_dir=./slam_omni/exp_speech2speech_rerun
+  python3 ./slam_omni/web_demo.py \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --checkpoint-path $exp_dir/epoch-998.pt \
+    --use-flash-attn True \
+    --enable-speech-output True \
+    --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
+    --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "stage 7: download the sherpa-onnx Paraformer ASR model"
+  model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
+
+  if [ ! -d $model_path ]; then
+    pip install sherpa-onnx
+    wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
+    tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
+  fi
+fi
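The new stage 7 fetches the sherpa-onnx Paraformer model that the demo's new --asr-model-dir flag points at. As a hedged sketch (this helper is hypothetical, not part of the commit), a quick check that the extracted directory contains the two files web_demo.py later loads:

    import os

    def check_asr_model_dir(model_dir: str) -> None:
        # web_demo.py loads model.int8.onnx and tokens.txt from this directory.
        for name in ("model.int8.onnx", "tokens.txt"):
            path = os.path.join(model_dir, name)
            if not os.path.isfile(path):
                raise FileNotFoundError(f"missing {path}; rerun stage 7")

    check_asr_model_dir("local/sherpa-onnx-paraformer-zh-2023-09-14")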
@@ -630,7 +630,7 @@ class SPEECH_LLM(nn.Module):
             next_token_ids = topk_sampling(
                 next_token_logits,
             )
-            print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
+            # print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
             if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
                 break
             # current_speech_input_ids = next_token_ids  # Use the newly generated token ID as input for next step
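topk_sampling itself is defined elsewhere in the repo and not shown in this hunk. For orientation, a typical top-k sampler over logits of shape (batch, vocab) looks like the sketch below; the k and temperature defaults are illustrative, not the repo's:

    import torch

    def topk_sampling(logits: torch.Tensor, top_k: int = 50, temperature: float = 1.0) -> torch.Tensor:
        # Mask everything outside the k largest logits, then sample from what remains.
        logits = logits / temperature
        kth_best = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # (batch, 1), hence next_token_ids[0, 0]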
@@ -6,7 +6,7 @@ import gradio as gr
 import soundfile as sf
 
-import gradio.processing_utils as processing_utils
+import tempfile
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from gradio_client import utils as client_utils
@@ -17,8 +17,11 @@ from peft import LoraConfig, get_peft_model
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from model import SPEECH_LLM, EncoderProjector
 from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 
+import sherpa_onnx
 from cosyvoice.cli.cosyvoice import CosyVoice
 import sys
 sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS')
 
 
 def get_model(params, device="cuda"):
     """Load and prepare the speech-to-speech model."""
@@ -177,26 +180,26 @@ def preprocess(
     return input_ids, attention_mask
 
 
-def _launch_demo(args, model, tokenizer, token2wav_model):
+def _launch_demo(args, model, tokenizer, token2wav_model, asr_model):
 
     def format_history(history: list):
         messages = []
         for item in history:
             if isinstance(item["content"], str):
                 messages.append({"role": item['role'], "content": item['content']})
-            elif item["role"] == "user" and (isinstance(item["content"], list) or
-                                             isinstance(item["content"], tuple)):
-                file_path = item["content"][0]
-                # TODO: check if the file_path's transcript is already in the history
-                mime_type = client_utils.get_mimetype(file_path)
-                if mime_type.startswith("audio"):
-                    messages.append({
-                        "role":
-                        item['role'],
-                        "content": item["content"][1]  # append audio transcript here
-                    })
+            # elif item["role"] == "user" and (isinstance(item["content"], list) or
+            #                                  isinstance(item["content"], tuple)):
+            #     file_path = item["content"][0]
+            #     # TODO: check if the file_path's transcript is already in the history
+            #     mime_type = client_utils.get_mimetype(file_path)
+            #     if mime_type.startswith("audio"):
+            #         messages.append({
+            #             "role":
+            #             item['role'],
+            #             "content": item["content"][1]  # append audio transcript here
+            #         })
         print('predict history: ', messages)
-        messages = messages[-2:]  # TODO: WAR: add history later
+        # messages = messages[-2:]  # TODO: WAR: add history later
         return messages
 
     def decode(
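This hunk is the heart of the commit: user turns are now stored in the history as transcript strings, the tuple-content branch becomes dead code, and format_history stops truncating to the last user/assistant pair, so earlier turns feed back into the prompt. A toy illustration of what it now returns (contents are made up):

    history = [
        {"role": "user", "content": "What's the weather today?"},  # earlier turn, already a transcript
        {"role": "assistant", "content": "It is sunny."},
        {"role": "user", "content": "And tomorrow?"},              # current turn, filled in by the ASR model
        {"role": "assistant", "content": ""},
    ]

    # Only string contents pass the isinstance check; before this commit the
    # result was then cut down to the last two entries, now all four survive.
    messages = [m for m in history if isinstance(m["content"], str)]
    assert len(messages) == 4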
@@ -214,8 +217,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
         dtype = torch.float32
         device = model.llm.device
 
-        feature = feature.to(device, dtype=dtype).transpose(1, 2)
-        assert feature.shape[2] == 80
+        feature = feature.to(device, dtype=dtype)  # .transpose(1, 2)
+        # assert feature.shape[2] == 80
 
         input_ids, attention_mask = preprocess([messages], tokenizer)
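For context on the dropped transpose and shape assert: whisper.log_mel_spectrogram returns features laid out as (n_mels, n_frames), which the encoder path evidently now consumes directly. A standalone sketch, assuming openai-whisper's API and a one-second 16 kHz input:

    import numpy as np
    import whisper

    audio = np.zeros(16000, dtype=np.float32)   # 1 s of silence at 16 kHz
    fbank = whisper.log_mel_spectrogram(audio)  # torch.Tensor of shape (80, n_frames)
    fbank = fbank.unsqueeze(0)                  # (1, 80, n_frames), as decode() builds later
    assert fbank.ndim == 3 and fbank.shape[1] == 80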
@@ -224,7 +227,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
         )
 
         hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        yield {"type": "text", "data": hyps}
+        # print('hyps: ', hyps, 23333333333333333333333333)
+        yield {"type": "text", "data": hyps[0]}
+        # yield {"type": "text", "data": hyps}
 
         audio_tokens = [token for token in audio_tokens if token < 4096]
         audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
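tokenizer.batch_decode returns a list of strings, so the old yield handed the UI a one-element list; hyps[0] sends the plain string the chat box expects. A toy consumer of the same {"type": ..., "data": ...} chunk protocol (names invented for illustration):

    def consume(chunks):
        # Concatenating works because each text "data" is now a str, not a list.
        return "".join(c["data"] for c in chunks if c["type"] == "text")

    assert consume([{"type": "text", "data": "hel"}, {"type": "text", "data": "lo"}]) == "hello"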
@@ -232,6 +237,10 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
         audio = audio_hat.squeeze(0).cpu().numpy()
         # sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050)
         audio = np.array(audio * 32767).astype(np.int16)
+        # yield {"type": "audio", "data": (22050, audio)}
+        # with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
+        #     sf.write(tmpfile.name, audio, 22050, format="WAV")
+        #     audio_path = tmpfile.name
         wav_io = io.BytesIO()
         sf.write(wav_io, audio, samplerate=22050, format="WAV")
         wav_io.seek(0)
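The tempfile route is left commented out in favor of building the WAV in memory. The same pattern stand-alone (soundfile accepts file-like objects; the sine payload is illustrative):

    import io

    import numpy as np
    import soundfile as sf

    # Float samples in [-1, 1] -> 16-bit PCM, mirroring the diff above.
    t = np.linspace(0, 1.0, 22050, endpoint=False)
    audio = (0.5 * np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)

    wav_io = io.BytesIO()
    sf.write(wav_io, audio, samplerate=22050, format="WAV")
    wav_io.seek(0)
    wav_bytes = wav_io.getvalue()  # ready to hand to the browser without touching disk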
@@ -249,18 +258,22 @@ def _launch_demo(args, model, tokenizer, token2wav_model):
             gr.update(visible=False),  # submit_btn
             gr.update(visible=True),  # stop_btn
         )
 
-        assert audio is not None
-        # get audio transcript here
-        print(2333, history, audio)
-        history.append({"role": "user", "content": (audio,)})
+        history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"})
         history.append({"role": "assistant", "content": ""})
-        formatted_history = format_history(history=history)
+        formatted_history = format_history(history=history)  # only keep string text format
 
-        # audio_transcript = get_audio_transcript(audio)
-        audio_transcript = "audio transcript"
-        history[-2]["content"] = (audio, audio_transcript)
-        fbank = whisper.log_mel_spectrogram(audio, model.llm.device)
-        print('fbank: ', fbank.shape)
+        assert audio is not None
+        audio_transcript = get_transcript(
+            audio,
+            asr_model,
+        )
+        print('audio_transcript: ', audio_transcript)
+        history[-2]["content"] = audio_transcript
+
+        fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device)
         fbank = fbank.unsqueeze(0)
         assert fbank.ndim == 3
 
         # history.append({"role": "assistant", "content": ""})
@@ -342,6 +355,10 @@ def _get_args():
                         type=str,
                         default=None,
                         help='Token2Wav path, default to %(default)r')
+    parser.add_argument('--asr-model-dir',
+                        type=str,
+                        default=None,
+                        help='ASR model dir, default to %(default)r')
     parser.add_argument('--flash-attn2',
                         action='store_true',
                         default=False,
@@ -354,14 +371,56 @@ def _get_args():
                         action='store_true',
                         default=False,
                         help='Automatically launch the interface in a new tab on the default browser.')
-    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
+    parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.')
     parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
     add_model_arguments(parser)
     args = parser.parse_args()
     return args
 
 
+def read_wave(wave_filename: str):
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and can be of type
+        32-bit floating point PCM. Its sample rate does not need to be 24kHz.
+
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples,
+         which are normalized to the range [-1, 1].
+       - Sample rate of the wave file.
+    """
+
+    samples, sample_rate = sf.read(wave_filename, dtype="float32")
+    assert (
+        samples.ndim == 1
+    ), f"Expected single channel, but got {samples.ndim} channels."
+
+    samples_float32 = samples.astype(np.float32)
+
+    return samples_float32, sample_rate
+
+
+def get_transcript(audio_path, recognizer):
+    samples, sample_rate = read_wave(audio_path)
+    s = recognizer.create_stream()
+    s.accept_waveform(sample_rate, samples)
+    recognizer.decode_streams([s])
+    return s.result.text
+
+
 if __name__ == "__main__":
     args = _get_args()
     model, tokenizer = get_model(args)
-    cosyvoice = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
-    _launch_demo(args, model, tokenizer, cosyvoice)
+    token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False)
+
+    asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer(
+        paraformer=f"{args.asr_model_dir}/model.int8.onnx",
+        tokens=f"{args.asr_model_dir}/tokens.txt",
+        num_threads=2,
+        sample_rate=16000,
+        feature_dim=80,
+        decoding_method="greedy_search",
+        debug=False,
+    )
+
+    _launch_demo(args, model, tokenizer, token2wav, asr_model)
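Putting the pieces together, the transcript path this commit adds is read_wave -> OfflineRecognizer -> get_transcript. A minimal end-to-end sketch using the same sherpa-onnx calls as above (the wav path is illustrative):

    import sherpa_onnx
    import soundfile as sf

    model_dir = "local/sherpa-onnx-paraformer-zh-2023-09-14"  # fetched by stage 7
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=f"{model_dir}/model.int8.onnx",
        tokens=f"{model_dir}/tokens.txt",
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )

    samples, sample_rate = sf.read("question.wav", dtype="float32")  # mono wav
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_streams([stream])
    print(stream.result.text)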