diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py index f05e5c1ac..f79e5b64c 100755 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py @@ -347,42 +347,6 @@ def compute_loss( return input_ids, attention_mask, target_ids - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. - See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - max_frames = params.max_duration * 1000 // params.frame_shift_ms allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) batch = filter_uneven_sized_batch(batch, allowed_max_frames) @@ -400,15 +364,15 @@ def compute_loss( questions_with_history = [cut.custom["question"] for cut in batch["supervisions"]["cut"]] answer_cosyvoice_speech_token = [cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]] last_questions = [question.split(': ')[-1].strip() for question in questions_with_history] - - texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [normalize_text_alimeeting(text) for text in texts] + history_contexts = [question.rsplit(':', 1)[0].strip() for question in questions_with_history] + # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。: 告诉我如何烹饪鸡肉 + # : 对以下句子进行鉴赏:他心地善良。输出结果为“他是一个有善心的人。 messages = [] for i, text in enumerate(texts): + history_context = history_contexts[i] message = [ - {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, + {"role": "user", "content": f"{history_context}{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "assistant", "content": text}, ] messages.append(message)