diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py index f79e5b64c..7fc207455 100755 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py @@ -303,6 +303,7 @@ def compute_loss( texts = [] TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" for i, msg in enumerate(messages): + print(msg,23333333333333) texts.append( tokenizer.apply_chat_template( msg, @@ -334,9 +335,14 @@ def compute_loss( # first get the indices of the tokens mask_prompt = True if mask_prompt: + default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) + default_speech_token_indices = torch.where( + input_ids == default_speech_token_id + ) mask_indices = torch.where( input_ids == tokenizer.convert_tokens_to_ids("assistant") ) + print(mask_indices, default_speech_token_indices, default_speech_token_id) for i in range(mask_indices[0].size(0)): row = mask_indices[0][i] col = mask_indices[1][i] @@ -362,6 +368,7 @@ def compute_loss( answers = batch["supervisions"]["text"] questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]] + chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]] answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]] last_questions = [question.split(': ')[-1].strip() for question in questions_with_history] history_contexts = [question.rsplit(':', 1)[0].strip() for question in questions_with_history] @@ -369,11 +376,20 @@ def compute_loss( # : 对以下句子进行鉴赏:他心地善良。输出结果为“他是一个有善心的人。 messages = [] - for i, text in enumerate(texts): - history_context = history_contexts[i] - message = [ - {"role": "user", "content": f"{history_context}{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, - {"role": "assistant", "content": text}, + for i, total_round in enumerate(chat_rounds): + message = [] + if total_round > 1: + history_question_answer = history_contexts[i].split('USER:') + for j in range(total_round - 1): + # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。 + question_answer = history_question_answer[j].split('ASSISTANT:') + message += [ + {"role": "user", "content": question_answer[0].strip()}, + {"role": "assistant", "content": question_answer[1].strip()} + ] + message += [ + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}, + {"role": "assistant", "content": answers[i]} ] messages.append(message)