fix template

root authored 2024-06-07 07:23:13 +00:00; committed by Yuekai Zhang
parent 16f18080be
commit 8afb0d647f
2 changed files with 9 additions and 4 deletions


@@ -238,12 +238,14 @@ def decode_one_batch(
 ) -> Dict:
     """Preprocesses the data for supervised fine-tuning."""
     texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
     for i, msg in enumerate(messages):
         texts.append(
             tokenizer.apply_chat_template(
                 msg,
                 tokenize=True,
                 add_generation_prompt=False,
+                chat_template=TEMPLATE,
                 padding="max_length",
                 max_length=max_len,
                 truncation=True,

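For context, the added TEMPLATE is a Jinja chat template that wraps each message in Qwen-style <|im_start|>/<|im_end|> markers. A rough sketch (not part of the commit) of the string it renders, using plain jinja2 instead of tokenizer.apply_chat_template; the example messages are invented:

# Sketch: render TEMPLATE with jinja2 to inspect the resulting text.
from jinja2 import Template

TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
    "{% if loop.last %}{{ '<|im_end|>'}}"
    "{% else %}{{ '<|im_end|>\n' }}{% endif %}"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "<speech>Transcribe the audio."},
    {"role": "assistant", "content": "hello world"},
]
print(Template(TEMPLATE).render(messages=messages))
# <|im_start|>user
# <speech>Transcribe the audio.<|im_end|>
# <|im_start|>assistant
# hello world<|im_end|>

Note that every turn ends with <|im_end|>, and only non-final turns get a trailing newline.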

@@ -80,7 +80,7 @@ from icefall.utils import (
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import transformers
 from transformers.trainer_pt_utils import LabelSmoother
-IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+#IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 DEFAULT_SPEECH_TOKEN = "<speech>"
 
 def set_batch_count(model: nn.Module, batch_count: float) -> None:
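For reference, LabelSmoother.ignore_index is -100, the same label value that PyTorch's cross-entropy loss skips by default, so positions set to IGNORE_TOKEN_ID contribute nothing to the loss. A quick check (not part of the commit):

# Both constants are -100.
from torch.nn import CrossEntropyLoss
from transformers.trainer_pt_utils import LabelSmoother

assert LabelSmoother.ignore_index == -100
assert CrossEntropyLoss().ignore_index == -100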
@@ -400,11 +400,13 @@ def compute_loss(
 ) -> Dict:
     """Preprocesses the data for supervised fine-tuning."""
     texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
     for i, msg in enumerate(messages):
         texts.append(
             tokenizer.apply_chat_template(
                 msg,
                 tokenize=True,
+                chat_template=TEMPLATE,
                 add_generation_prompt=False,
                 padding="max_length",
                 max_length=max_len,
@@ -418,15 +420,16 @@ def compute_loss(
     target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
     # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
     # first get the indices of the tokens
-    mask_prompt = False
+    mask_prompt = True
     if mask_prompt:
-        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
+        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
         # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
         # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
         for i in range(mask_indices[0].size(0)):
             row = mask_indices[0][i]
             col = mask_indices[1][i]
-            target_ids[row, :col+4] = IGNORE_TOKEN_ID
+            # + 2 to skip: 'assistant', '\n'
+            target_ids[row, :col+2] = IGNORE_TOKEN_ID
 
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
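To see what the new masking does, here is a rough sketch (not part of the commit) on a toy batch. Token ids are invented; ASSISTANT_ID stands in for tokenizer.convert_tokens_to_ids("assistant"):

# Sketch of the prompt-masking step on a one-row toy batch.
import torch

IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index
ASSISTANT_ID = 77       # hypothetical id of the "assistant" token

# One row: prompt tokens ..., 'assistant', '\n', then response tokens.
input_ids = torch.tensor([[10, 20, 30, 40, 11, 77, 99, 51, 52]])
target_ids = input_ids.clone()

mask_indices = torch.where(input_ids == ASSISTANT_ID)
for i in range(mask_indices[0].size(0)):
    row = mask_indices[0][i]
    col = mask_indices[1][i]
    # + 2 skips the 'assistant' token and the '\n' right after it,
    # leaving only the response tokens as training targets.
    target_ids[row, : col + 2] = IGNORE_TOKEN_ID

print(target_ids)
# tensor([[-100, -100, -100, -100, -100, -100, -100,   51,   52]])

Everything up to and including 'assistant' and the following '\n' is ignored by the loss, so training signal comes only from the response; the old col+4 offset over-masked by two tokens once the match point moved from the speech token to 'assistant'.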