mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
fix multi rounds data
This commit is contained in:
parent
202d764cfb
commit
1d11662016
@ -303,6 +303,7 @@ def compute_loss(
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
print(msg,23333333333333)
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
@ -334,9 +335,14 @@ def compute_loss(
|
||||
# first get the indices of the tokens
|
||||
mask_prompt = True
|
||||
if mask_prompt:
|
||||
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
|
||||
default_speech_token_indices = torch.where(
|
||||
input_ids == default_speech_token_id
|
||||
)
|
||||
mask_indices = torch.where(
|
||||
input_ids == tokenizer.convert_tokens_to_ids("assistant")
|
||||
)
|
||||
print(mask_indices, default_speech_token_indices, default_speech_token_id)
|
||||
for i in range(mask_indices[0].size(0)):
|
||||
row = mask_indices[0][i]
|
||||
col = mask_indices[1][i]
|
||||
@ -362,6 +368,7 @@ def compute_loss(
|
||||
|
||||
answers = batch["supervisions"]["text"]
|
||||
questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]]
|
||||
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
|
||||
answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]]
|
||||
last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
|
||||
history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
|
||||
@ -369,11 +376,20 @@ def compute_loss(
|
||||
# <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为“他是一个有善心的人。
|
||||
|
||||
messages = []
|
||||
for i, text in enumerate(texts):
|
||||
history_context = history_contexts[i]
|
||||
message = [
|
||||
{"role": "user", "content": f"{history_context}{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
|
||||
{"role": "assistant", "content": text},
|
||||
for i, total_round in enumerate(chat_rounds):
|
||||
message = []
|
||||
if total_round > 1:
|
||||
history_question_answer = history_contexts[i].split('USER:')
|
||||
for j in range(total_round - 1):
|
||||
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
|
||||
question_answer = history_question_answer[j].split('ASSISTANT:')
|
||||
message += [
|
||||
{"role": "user", "content": question_answer[0].strip()},
|
||||
{"role": "assistant", "content": question_answer[1].strip()}
|
||||
]
|
||||
message += [
|
||||
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": answers[i]}
|
||||
]
|
||||
messages.append(message)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user