mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
change prompt
This commit is contained in:
parent
68b99f456f
commit
40e4ac480c
@ -418,13 +418,15 @@ def compute_loss(
|
|||||||
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||||
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
|
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
|
||||||
# first get the indices of the tokens
|
# first get the indices of the tokens
|
||||||
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
|
mask_prompt = False
|
||||||
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
|
if mask_prompt:
|
||||||
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
|
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
|
||||||
for i in range(mask_indices[0].size(0)):
|
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
|
||||||
row = mask_indices[0][i]
|
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
|
||||||
col = mask_indices[1][i]
|
for i in range(mask_indices[0].size(0)):
|
||||||
target_ids[row, :col+4] = IGNORE_TOKEN_ID
|
row = mask_indices[0][i]
|
||||||
|
col = mask_indices[1][i]
|
||||||
|
target_ids[row, :col+4] = IGNORE_TOKEN_ID
|
||||||
|
|
||||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
@ -485,9 +487,13 @@ def compute_loss(
|
|||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
|
# message = [
|
||||||
|
# {"role": "system", "content": "你是一个能处理音频的助手。"},
|
||||||
|
# {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
|
||||||
|
# {"role": "assistant", "content": text},
|
||||||
|
# ]
|
||||||
message = [
|
message = [
|
||||||
{"role": "system", "content": "你是一个能处理音频的助手。"},
|
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
|
||||||
{"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
|
|
||||||
{"role": "assistant", "content": text},
|
{"role": "assistant", "content": text},
|
||||||
]
|
]
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user