diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index e23f26684..75bd9c576 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -52,12 +52,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "stage 2: "
   python3 ./slam_omni/decode.py \
     --max-duration 80 \
-    --exp-dir slam_omni/exp_test_whisper_qwen2_1.5B \
+    --exp-dir slam_omni/exp_speech2text \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/qwen \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
+    --method pure_text_sampling \
+    --use-lora True # --on-the-fly-feats True
+
+fi
+
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "stage 20: "
+  python3 ./slam_omni/decode.py \
+    --max-duration 80 \
+    --exp-dir slam_omni/exp_speech2text \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --epoch 999 --avg 1 \
+    --manifest-dir data/fbank \
+    --use-flash-attn True \
+    --method pure_text_sampling_original_0.5B \
     --use-lora False # --on-the-fly-feats True
 
 fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
index f878d32e7..3feef8f1c 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
@@ -52,7 +52,6 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-import k2
 import torch
 import torch.nn as nn
 import transformers
@@ -60,13 +59,12 @@ import whisper
 from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
-# from data_module import MultiDataset
+
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -76,7 +74,6 @@ from icefall.utils import (
     write_error_stats,
 )
 
-
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
@@ -133,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--llm-path-or-name",
         type=str,
-        default="/workspace/asr/Qwen1.5-0.5B-Chat",
+        default="",
         help="Path or name of the large language model.",
     )
 
@@ -264,7 +261,6 @@ def decode_one_batch(
     def preprocess(
         messages,
         tokenizer: transformers.PreTrainedTokenizer,
-        max_len: int = 128,
     ) -> Dict:
         """Preprocesses the data for supervised fine-tuning."""
         texts = []
@@ -277,8 +273,7 @@ def decode_one_batch(
                     add_generation_prompt=False,
                     chat_template=TEMPLATE,
                     padding="longest",
-                    max_length=max_len,
-                    truncation=True,
+                    truncation=False,
                 )
             )
         max_len_texts = max([len(text) for text in texts])
@@ -318,18 +313,38 @@ def decode_one_batch(
         2,
     )
-    # supervisions = batch["supervisions"]
-    # feature_len = supervisions["num_frames"]
-    # feature_len = feature_len.to(device, dtype=dtype)
+    chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
 
-    messages = [
-        [
-            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, - {"role": "assistant", "content": ""}, + # messages = [ + # [ + # {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, + # {"role": "assistant", "content": ""}, + # ] + # ] * len(feature) + questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]] + history_contexts = [question.rsplit(':', 1)[0].strip() for question in questions_with_history] + last_questions = [question.split(': ')[-1].strip() for question in questions_with_history] + messages = [] + for i, total_round in enumerate(chat_rounds): + message = [] + if total_round > 1: + history_question_answer = history_contexts[i].split('USER:') + history_question_answer = [item for item in history_question_answer if item] + for j in range(total_round - 1): + # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。 + question_answer = history_question_answer[j].split('ASSISTANT:') + message += [ + {"role": "user", "content": question_answer[0].strip()}, + {"role": "assistant", "content": question_answer[1].strip()} + ] + message += [ + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}, + # {"role": "user", "content": f"{last_questions[i]}"}, + {"role": "assistant", "content": ""} ] - ] * len(feature) + messages.append(message) - input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128) + input_ids, attention_mask = preprocess(messages, tokenizer) generated_ids = model.decode( feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device) @@ -422,7 +437,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = normalize_text_alimeeting(ref_text) + # ref_text = normalize_text_alimeeting(ref_text) ref_words = ref_text.split() print(f"ref: {ref_text}") print(f"hyp: {''.join(hyp_words)}") @@ -449,7 +464,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.log_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -459,7 +474,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) # we compute CER for aishell dataset. 
         results_char = []
@@ -475,7 +490,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+    errs_info = params.log_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -499,8 +514,10 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    params.log_dir = Path(params.exp_dir) / f"log-{params.method}"
+    params.log_dir.mkdir(parents=True, exist_ok=True)
     setup_logger(
-        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+        f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}"
     )
 
     logging.info("Decoding started")
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
index 829ef4e2d..5126a5d34 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
@@ -241,27 +241,41 @@ class SPEECH_LLM(nn.Module):
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,
-            attention_mask,
             _,
-            position_ids,
+            _,
+            _,
         ) = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask
         )
 
         generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
-            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 1024),
             num_beams=kwargs.get("num_beams", 1),
-            do_sample=kwargs.get("do_sample", False),
+            do_sample=kwargs.get("do_sample", True),
             min_length=kwargs.get("min_length", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            temperature=kwargs.get("temperature", 1.0),
+            top_p=kwargs.get("top_p", 0.5),
+            top_k=kwargs.get("top_k", 20),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.1),
+            temperature=kwargs.get("temperature", 0.7),
             bos_token_id=self.llm.config.bos_token_id,
             eos_token_id=self.llm.config.eos_token_id,
             pad_token_id=self.llm.config.pad_token_id,
         )
+        # generated_ids = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 1),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     bos_token_id=self.llm.config.bos_token_id,
+        #     eos_token_id=self.llm.config.eos_token_id,
+        #     pad_token_id=self.llm.config.pad_token_id,
+        # )
 
         return generated_ids
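
Note on the new multi-turn prompt construction in decode.py: each cut is expected to carry cut.custom["round"] (the number of dialogue rounds) and cut.custom["question"] (the accumulated history as a single "USER: ... ASSISTANT: ..." string ending with the current question, which decode.py strips off with rsplit(':', 1)). Below is a minimal standalone sketch of the history-to-chat-messages step, for illustration only: the example history string is invented, DEFAULT_SPEECH_TOKEN is stubbed rather than imported from train.py, and the exact history format is an assumption based on the in-code comment.

# Illustrative sketch, not part of the patch: rebuild the chat "messages" list for
# one cut the way the new decode.py loop does.
DEFAULT_SPEECH_TOKEN = "<speech>"  # stub; decode.py imports this from train.py


def build_messages(history_context: str, total_round: int) -> list:
    """history_context holds past turns, e.g. 'USER: q1 ASSISTANT: a1 USER: q2 ASSISTANT: a2'."""
    message = []
    if total_round > 1:
        # Each non-empty chunk between "USER:" markers is one question/answer pair.
        turns = [t for t in history_context.split("USER:") if t]
        for j in range(total_round - 1):
            question, answer = turns[j].split("ASSISTANT:")
            message += [
                {"role": "user", "content": question.strip()},
                {"role": "assistant", "content": answer.strip()},
            ]
    # The current turn is the speech placeholder; the empty assistant turn mirrors
    # the original single-turn prompt and leaves room for the generated answer.
    message += [
        {"role": "user", "content": DEFAULT_SPEECH_TOKEN},
        {"role": "assistant", "content": ""},
    ]
    return message


# Example with two completed rounds before the current spoken question:
history = "USER: Write a poem about summer. ASSISTANT: Sunny days. USER: Any news? ASSISTANT: Plenty."
print(build_messages(history, total_round=3))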