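Fix chat template: pass an explicit chat_template to tokenizer.apply_chat_template and mask prompt tokens up to the "assistant" token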

This commit is contained in:
root 2024-06-07 07:23:13 +00:00 committed by Yuekai Zhang
parent 16f18080be
commit 8afb0d647f
2 changed files with 9 additions and 4 deletions

View File

@ -238,12 +238,14 @@ def decode_one_batch(
) -> Dict: ) -> Dict:
"""Preprocesses the data for supervised fine-tuning.""" """Preprocesses the data for supervised fine-tuning."""
texts = [] texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
texts.append( texts.append(
tokenizer.apply_chat_template( tokenizer.apply_chat_template(
msg, msg,
tokenize=True, tokenize=True,
add_generation_prompt=False, add_generation_prompt=False,
chat_template=TEMPLATE,
padding="max_length", padding="max_length",
max_length=max_len, max_length=max_len,
truncation=True, truncation=True,

View File

@ -80,7 +80,7 @@ from icefall.utils import (
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers import transformers
from transformers.trainer_pt_utils import LabelSmoother from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index #IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>" DEFAULT_SPEECH_TOKEN = "<speech>"
def set_batch_count(model: nn.Module, batch_count: float) -> None: def set_batch_count(model: nn.Module, batch_count: float) -> None:
@ -400,11 +400,13 @@ def compute_loss(
) -> Dict: ) -> Dict:
"""Preprocesses the data for supervised fine-tuning.""" """Preprocesses the data for supervised fine-tuning."""
texts = [] texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
texts.append( texts.append(
tokenizer.apply_chat_template( tokenizer.apply_chat_template(
msg, msg,
tokenize=True, tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False, add_generation_prompt=False,
padding="max_length", padding="max_length",
max_length=max_len, max_length=max_len,
@ -418,15 +420,16 @@ def compute_loss(
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
# first get the indices of the tokens # first get the indices of the tokens
mask_prompt = False mask_prompt = True
if mask_prompt: if mask_prompt:
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)) mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644 # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
for i in range(mask_indices[0].size(0)): for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i] row = mask_indices[0][i]
col = mask_indices[1][i] col = mask_indices[1][i]
target_ids[row, :col+4] = IGNORE_TOKEN_ID # + 2 to skip: 'assistant', '\n'
target_ids[row, :col+2] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id) attention_mask = input_ids.ne(tokenizer.pad_token_id)