From 8afb0d647f95a8bd55e935e6f586b186e9f901f1 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 7 Jun 2024 07:23:13 +0000
Subject: [PATCH] fix template

---
 egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py |  2 ++
 egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py  | 11 +++++++----
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index d8235e798..b5783d5dd 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -238,12 +238,14 @@ def decode_one_batch(
 ) -> Dict:
     """Preprocesses the data for supervised fine-tuning."""
     texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
     for i, msg in enumerate(messages):
         texts.append(
             tokenizer.apply_chat_template(
                 msg,
                 tokenize=True,
                 add_generation_prompt=False,
+                chat_template=TEMPLATE,
                 padding="max_length",
                 max_length=max_len,
                 truncation=True,
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 941f49081..9b650e747 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -80,7 +80,7 @@ from icefall.utils import (
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import transformers
 from transformers.trainer_pt_utils import LabelSmoother
-IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+#IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 DEFAULT_SPEECH_TOKEN = "<speech>"

 def set_batch_count(model: nn.Module, batch_count: float) -> None:
@@ -400,11 +400,13 @@ def compute_loss(
 ) -> Dict:
     """Preprocesses the data for supervised fine-tuning."""
     texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
     for i, msg in enumerate(messages):
         texts.append(
             tokenizer.apply_chat_template(
                 msg,
                 tokenize=True,
+                chat_template=TEMPLATE,
                 add_generation_prompt=False,
                 padding="max_length",
                 max_length=max_len,
@@ -418,15 +420,16 @@
     target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
     # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
     # first get the indices of the tokens
-    mask_prompt = False
+    mask_prompt = True
     if mask_prompt:
-        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
+        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
         # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
         # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
         for i in range(mask_indices[0].size(0)):
             row = mask_indices[0][i]
             col = mask_indices[1][i]
-            target_ids[row, :col+4] = IGNORE_TOKEN_ID
+            # + 2 to skip: 'assistant', '\n'
+            target_ids[row, :col+2] = IGNORE_TOKEN_ID

     attention_mask = input_ids.ne(tokenizer.pad_token_id)
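
Below is a minimal, self-contained sketch of the label-masking idea the patch switches to: locate the literal "assistant" token that the custom chat template places right before the reply, then set every label up to and including that token and the following "\n" token to IGNORE_TOKEN_ID (-100, LabelSmoother.ignore_index), so the loss is computed only on the assistant's answer. All token ids here are placeholders chosen for illustration, not values taken from a real tokenizer.

import torch

IGNORE_TOKEN_ID = -100        # LabelSmoother.ignore_index
ASSISTANT_TOKEN_ID = 77091    # illustrative id for the single token "assistant"

# Two toy sequences shaped like:
#   <|im_start|> user ... <|im_end|> <|im_start|> assistant \n <reply> ... <|im_end|>
input_ids = torch.tensor(
    [
        [151644, 872, 9000, 151645, 151644, 77091, 198, 9001, 9002, 151645],
        [151644, 872, 9010, 151645, 151644, 77091, 198, 9011, 151645, 151643],
    ]
)
target_ids = input_ids.clone()

# Find every occurrence of the "assistant" token, then mask everything up to
# and including that token and the following "\n" token (hence the + 2), so
# only the reply tokens remain as training targets.
rows, cols = torch.where(input_ids == ASSISTANT_TOKEN_ID)
for row, col in zip(rows.tolist(), cols.tolist()):
    target_ids[row, : col + 2] = IGNORE_TOKEN_ID

print(target_ids)

With the template added in this patch, "assistant" and "\n" are the last two prompt tokens before the reply, which is what makes the + 2 offset (rather than the earlier + 4) sufficient.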