change prompt

This commit is contained in:
Yuekai Zhang 2024-06-07 10:14:28 +08:00
parent 68b99f456f
commit 40e4ac480c

View File

@@ -418,13 +418,15 @@ def compute_loss(
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
target_ids[row, :col+4] = IGNORE_TOKEN_ID
mask_prompt = False
if mask_prompt:
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
target_ids[row, :col+4] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
@@ -485,9 +487,13 @@ def compute_loss(
messages = []
for i, text in enumerate(texts):
# message = [
# {"role": "system", "content": "你是一个能处理音频的助手。"},
# {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": text},
# ]
message = [
{"role": "system", "content": "你是一个能处理音频的助手。"},
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text},
]
messages.append(message)