fix multi rounds data

This commit is contained in:
Yuekai Zhang 2025-04-14 14:32:42 +08:00
parent 202d764cfb
commit 1d11662016

View File

@ -303,6 +303,7 @@ def compute_loss(
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
print(msg,23333333333333)
texts.append(
tokenizer.apply_chat_template(
msg,
@ -334,9 +335,14 @@ def compute_loss(
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
default_speech_token_indices = torch.where(
input_ids == default_speech_token_id
)
mask_indices = torch.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)
print(mask_indices, default_speech_token_indices, default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
@ -362,6 +368,7 @@ def compute_loss(
answers = batch["supervisions"]["text"]
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
@ -369,11 +376,20 @@ def compute_loss(
# <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为“他是一个有善心的人。
messages = []
for i, text in enumerate(texts):
history_context = history_contexts[i]
message = [
{"role": "user", "content": f"{history_context}{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text},
for i, total_round in enumerate(chat_rounds):
message = []
if total_round > 1:
history_question_answer = history_contexts[i].split('USER:')
for j in range(total_round - 1):
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
question_answer = history_question_answer[j].split('ASSISTANT:')
message += [
{"role": "user", "content": question_answer[0].strip()},
{"role": "assistant", "content": question_answer[1].strip()}
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": answers[i]}
]
messages.append(message)