mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
fix template
This commit is contained in:
parent
16f18080be
commit
8afb0d647f
@ -238,12 +238,14 @@ def decode_one_batch(
|
|||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Preprocesses the data for supervised fine-tuning."""
|
"""Preprocesses the data for supervised fine-tuning."""
|
||||||
texts = []
|
texts = []
|
||||||
|
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
texts.append(
|
texts.append(
|
||||||
tokenizer.apply_chat_template(
|
tokenizer.apply_chat_template(
|
||||||
msg,
|
msg,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
|
chat_template=TEMPLATE,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
max_length=max_len,
|
max_length=max_len,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
@ -80,7 +80,7 @@ from icefall.utils import (
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
import transformers
|
import transformers
|
||||||
from transformers.trainer_pt_utils import LabelSmoother
|
from transformers.trainer_pt_utils import LabelSmoother
|
||||||
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
#IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
||||||
DEFAULT_SPEECH_TOKEN = "<speech>"
|
DEFAULT_SPEECH_TOKEN = "<speech>"
|
||||||
|
|
||||||
def set_batch_count(model: nn.Module, batch_count: float) -> None:
|
def set_batch_count(model: nn.Module, batch_count: float) -> None:
|
||||||
@ -400,11 +400,13 @@ def compute_loss(
|
|||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Preprocesses the data for supervised fine-tuning."""
|
"""Preprocesses the data for supervised fine-tuning."""
|
||||||
texts = []
|
texts = []
|
||||||
|
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
texts.append(
|
texts.append(
|
||||||
tokenizer.apply_chat_template(
|
tokenizer.apply_chat_template(
|
||||||
msg,
|
msg,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
|
chat_template=TEMPLATE,
|
||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
max_length=max_len,
|
max_length=max_len,
|
||||||
@ -418,15 +420,16 @@ def compute_loss(
|
|||||||
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||||
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
|
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
|
||||||
# first get the indices of the tokens
|
# first get the indices of the tokens
|
||||||
mask_prompt = False
|
mask_prompt = True
|
||||||
if mask_prompt:
|
if mask_prompt:
|
||||||
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
|
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
|
||||||
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
|
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
|
||||||
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
|
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
|
||||||
for i in range(mask_indices[0].size(0)):
|
for i in range(mask_indices[0].size(0)):
|
||||||
row = mask_indices[0][i]
|
row = mask_indices[0][i]
|
||||||
col = mask_indices[1][i]
|
col = mask_indices[1][i]
|
||||||
target_ids[row, :col+4] = IGNORE_TOKEN_ID
|
# + 2 to skip: 'assistant', '\n'
|
||||||
|
target_ids[row, :col+2] = IGNORE_TOKEN_ID
|
||||||
|
|
||||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user