diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 7bdfc57dd..d8235e798 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -280,10 +280,14 @@ def decode_one_batch( feature_len = feature_len.to(device, dtype=dtype) messages = [[ - {"role": "system", "content": "你是一个能处理音频的助手。"}, - {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"}, + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "assistant", "content": ""}, ]] * len(feature) + # messages = [[ + # {"role": "system", "content": "你是一个能处理音频的助手。"}, + # {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"}, + # {"role": "assistant", "content": ""}, + # ]] * len(feature) input_ids, attention_mask = preprocess( messages, tokenizer, max_len=128