Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Commit 8afb0d647f: fix template
Parent: 16f18080be
@@ -238,12 +238,14 @@ def decode_one_batch(
) -> Dict:
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                add_generation_prompt=False,
                chat_template=TEMPLATE,
                padding="max_length",
                max_length=max_len,
                truncation=True,
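The TEMPLATE above is a ChatML-style Jinja template, so its output can be previewed by rendering it directly with Jinja2. Below is a minimal sketch; the messages list is made up for illustration and is not taken from the recipe, and for this simple template a plain Jinja2 render should match the string that tokenizer.apply_chat_template builds before tokenization.

# Minimal sketch: preview the string produced by the ChatML-style TEMPLATE.
# The example messages are hypothetical, not from the icefall recipe.
from jinja2 import Template

TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
    "{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "Transcribe the speech: <speech>"},
    {"role": "assistant", "content": "HELLO WORLD"},
]

print(Template(TEMPLATE).render(messages=messages))
# <|im_start|>user
# Transcribe the speech: <speech><|im_end|>
# <|im_start|>assistant
# HELLO WORLD<|im_end|>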
@@ -80,7 +80,7 @@ from icefall.utils import (
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
#IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"

def set_batch_count(model: nn.Module, batch_count: float) -> None:
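For context, LabelSmoother.ignore_index is -100, which is also the default ignore_index of PyTorch's cross-entropy loss, so every position in target_ids that is set to IGNORE_TOKEN_ID is simply dropped from the loss. A minimal sketch with toy logits and labels (values are illustrative only):

# Minimal sketch (toy values): labels equal to IGNORE_TOKEN_ID (-100) are
# excluded from the cross-entropy loss, which is what masking pad and
# prompt tokens in target_ids relies on.
import torch
import torch.nn.functional as F

IGNORE_TOKEN_ID = -100  # == LabelSmoother.ignore_index

vocab_size = 8
logits = torch.randn(1, 5, vocab_size)  # (batch, seq, vocab)
labels = torch.tensor([[IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, 3, 4, 5]])

# F.cross_entropy ignores index -100 by default, so only the last three
# positions contribute to the averaged loss.
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
print(loss)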
@@ -400,11 +400,13 @@ def compute_loss(
) -> Dict:
    """Preprocesses the data for supervised fine-tuning."""
    texts = []
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    for i, msg in enumerate(messages):
        texts.append(
            tokenizer.apply_chat_template(
                msg,
                tokenize=True,
                chat_template=TEMPLATE,
                add_generation_prompt=False,
                padding="max_length",
                max_length=max_len,
@@ -418,15 +420,16 @@ def compute_loss(
    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
    # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
    # first get the indices of the tokens
    mask_prompt = False
    mask_prompt = True
    if mask_prompt:
        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
        # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
        # target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
        for i in range(mask_indices[0].size(0)):
            row = mask_indices[0][i]
            col = mask_indices[1][i]
            target_ids[row, :col+4] = IGNORE_TOKEN_ID
            # + 2 to skip: 'assistant', '\n'
            target_ids[row, :col+2] = IGNORE_TOKEN_ID

    attention_mask = input_ids.ne(tokenizer.pad_token_id)