fix decoding issue and padding to longest

This commit is contained in:
root 2024-06-11 09:04:29 +00:00 committed by Yuekai Zhang
parent eb2c255e1e
commit 271536248f
3 changed files with 61 additions and 44 deletions

View File

@ -238,7 +238,7 @@ def decode_one_batch(
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
@ -246,11 +246,16 @@ def decode_one_batch(
tokenize=True,
add_generation_prompt=False,
chat_template=TEMPLATE,
padding="max_length",
padding="longest",
max_length=max_len,
truncation=True,
)
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == 'right':
texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
else:
texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
input_ids = torch.tensor(texts, dtype=torch.int)
@ -481,33 +486,36 @@ def main():
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16
torch_dtype=torch.float16
tokenizer.padding_side = 'left'
else:
attn_implementation = "eager"
torch_dtype=torch.float16
tokenizer.padding_side = 'right'
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
# tokenizer.padding_side = 'left'
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.bos_token_id = tokenizer.bos_token_id
llm.config.eos_token_id = tokenizer.eos_token_id
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
model = SPEECH_LLM(
speech_encoder,

View File

@ -1,12 +1,12 @@
from torch import nn
import torch
from transformers.trainer_pt_utils import LabelSmoother
from icefall.dist import get_rank
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module):
# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
def __init__(self, encoder_dim, llm_dim, downsample_rate=1):
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
@ -47,7 +47,6 @@ class SPEECH_LLM(nn.Module):
param.requires_grad = False
self.llm.eval()
self.encoder_projector = encoder_projector
self.encoder_outputs_downsample_rate = 4
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
num_speechs, speech_len, embed_dim = speech_features.shape
@ -134,33 +133,39 @@ class SPEECH_LLM(nn.Module):
labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
# encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
# print("input_ids", input_ids, input_ids.shape)
# print("labels", labels, labels.shape)
# print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
# print("attention_mask_before", attention_mask.shape, attention_mask)
# print(2333333333333333333333333333)
enable_logging = False
rank = get_rank()
# log only on rank 0, training using deep
if enable_logging and rank == 0:
print("input_ids", input_ids, input_ids.shape)
print("labels", labels, labels.shape)
print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
print("attention_mask_before", attention_mask.shape, attention_mask)
print(2333333333333333333333333333)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
# print("labels", labels, labels.shape)
# print("speech_features", speech_features.shape, speech_features)
# print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
# print("attention_mask", attention_mask.shape, attention_mask)
# print("position_ids", position_ids.shape, position_ids)
# print("================================================================")
# input()
if enable_logging and rank == 0:
print("speech_features", speech_features.shape, speech_features)
print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
print("attention_mask", attention_mask.shape, attention_mask)
print("position_ids", position_ids.shape, position_ids)
print("labels", labels, labels.shape)
print("================================================================")
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
# model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
if enable_logging and rank == 0:
print("preds", preds, preds.shape)
print(4555555555555555555555555555555555555555555)
acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
return model_outputs, acc
@ -173,9 +178,6 @@ class SPEECH_LLM(nn.Module):
):
encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
# encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
@ -196,10 +198,7 @@ class SPEECH_LLM(nn.Module):
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id
)
# print(generated_ids, input_ids)
# generated_ids = [
# output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
# ]
return generated_ids

View File

@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=4,
default=1,
help="Downsample rate for the encoder projector.",
)
@ -287,7 +287,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 10000,
"valid_interval": 5000,
"env_info": get_env_info(),
}
)
@ -408,12 +408,17 @@ def compute_loss(
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="max_length",
padding="longest", # FIX me change padding to longest
max_length=max_len,
truncation=True,
)
)
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == 'right':
texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
else:
texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone()
@ -423,8 +428,7 @@ def compute_loss(
mask_prompt = True
if mask_prompt:
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
# then mask all tokens before the first token e.g. 151646 (speech), 151645 <assistant>, 198 \n
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
@ -526,7 +530,7 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["acc"] = acc
info["acc"] = acc * info["frames"] # WAR: to avoid normalization by the number of frames
return loss, info
@ -743,22 +747,24 @@ def run(rank, world_size, args):
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16
torch_dtype=torch.float16
tokenizer.padding_side = 'left'
else:
attn_implementation = "eager"
torch_dtype=torch.float16
tokenizer.padding_side = 'right'
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
# tokenizer.padding_side = 'left'
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
}
@ -766,7 +772,7 @@ def run(rank, world_size, args):
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
model = SPEECH_LLM(
speech_encoder,
@ -774,6 +780,10 @@ def run(rank, world_size, args):
encoder_projector,
)
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")