change prompt

This commit is contained in:
Yuekai Zhang 2024-06-07 10:14:28 +08:00
parent 68b99f456f
commit 40e4ac480c

View File

@@ -418,6 +418,8 @@ def compute_loss(
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
# first get the indices of the tokens # first get the indices of the tokens
mask_prompt = False
if mask_prompt:
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)) mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644 # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
@@ -485,9 +487,13 @@ def compute_loss(
messages = [] messages = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
# message = [
# {"role": "system", "content": "你是一个能处理音频的助手。"},
# {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": text},
# ]
message = [ message = [
{"role": "system", "content": "你是一个能处理音频的助手。"}, {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": text}, {"role": "assistant", "content": text},
] ]
messages.append(message) messages.append(message)