From 40e4ac480c81aa9ae7bdcc708f0474aaa0b5bf92 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 7 Jun 2024 10:14:28 +0800 Subject: [PATCH] change prompt --- .../ASR_LLM/whisper_llm_zh/train.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 3272ce7f3..941f49081 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -418,13 +418,15 @@ def compute_loss( target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID # first get the indices of the tokens - mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)) - # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644 - # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID - for i in range(mask_indices[0].size(0)): - row = mask_indices[0][i] - col = mask_indices[1][i] - target_ids[row, :col+4] = IGNORE_TOKEN_ID + mask_prompt = False + if mask_prompt: + mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)) + # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644 + # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID + for i in range(mask_indices[0].size(0)): + row = mask_indices[0][i] + col = mask_indices[1][i] + target_ids[row, :col+4] = IGNORE_TOKEN_ID attention_mask = input_ids.ne(tokenizer.pad_token_id) @@ -485,9 +487,13 @@ def compute_loss( messages = [] for i, text in enumerate(texts): + # message = [ + # {"role": "system", "content": "你是一个能处理音频的助手。"}, + # {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"}, + # {"role": "assistant", "content": text}, + # ] message = [ - {"role": "system", "content": "你是一个能处理音频的助手。"}, - {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"}, + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "assistant", "content": text}, ] messages.append(message)