fix decoding issue and padding to longest

This commit is contained in:
root 2024-06-11 09:04:29 +00:00 committed by Yuekai Zhang
parent eb2c255e1e
commit 271536248f
3 changed files with 61 additions and 44 deletions

View File

@ -196,7 +196,7 @@ def get_parser():
default=True, default=True,
help="Whether to use flash attention.", help="Whether to use flash attention.",
) )
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -238,7 +238,7 @@ def decode_one_batch(
) -> Dict: ) -> Dict:
"""Preprocesses the data for supervised fine-tuning.""" """Preprocesses the data for supervised fine-tuning."""
texts = [] texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
texts.append( texts.append(
tokenizer.apply_chat_template( tokenizer.apply_chat_template(
@ -246,11 +246,16 @@ def decode_one_batch(
tokenize=True, tokenize=True,
add_generation_prompt=False, add_generation_prompt=False,
chat_template=TEMPLATE, chat_template=TEMPLATE,
padding="max_length", padding="longest",
max_length=max_len, max_length=max_len,
truncation=True, truncation=True,
) )
) )
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == 'right':
texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
else:
texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
input_ids = torch.tensor(texts, dtype=torch.int) input_ids = torch.tensor(texts, dtype=torch.int)
@ -481,33 +486,36 @@ def main():
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn: if params.use_flash_attn:
attn_implementation = "flash_attention_2" attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16 # torch_dtype=torch.bfloat16
torch_dtype=torch.float16 torch_dtype=torch.float16
tokenizer.padding_side = 'left'
else: else:
attn_implementation = "eager" attn_implementation = "eager"
torch_dtype=torch.float16 torch_dtype=torch.float16
tokenizer.padding_side = 'right'
llm = AutoModelForCausalLM.from_pretrained( llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name, params.llm_path_or_name,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
) )
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
# tokenizer.padding_side = 'left'
special_tokens_dict = { special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN] "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
} }
tokenizer.add_special_tokens(special_tokens_dict) tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.bos_token_id llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.eos_token_id llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size) encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
model = SPEECH_LLM( model = SPEECH_LLM(
speech_encoder, speech_encoder,

View File

@ -1,12 +1,12 @@
from torch import nn from torch import nn
import torch import torch
from transformers.trainer_pt_utils import LabelSmoother from transformers.trainer_pt_utils import LabelSmoother
from icefall.dist import get_rank
IGNORE_TOKEN_ID = LabelSmoother.ignore_index IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module): class EncoderProjector(nn.Module):
# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
def __init__(self, encoder_dim, llm_dim, downsample_rate=1): def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__() super().__init__()
self.downsample_rate = downsample_rate self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim) self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
@ -47,7 +47,6 @@ class SPEECH_LLM(nn.Module):
param.requires_grad = False param.requires_grad = False
self.llm.eval() self.llm.eval()
self.encoder_projector = encoder_projector self.encoder_projector = encoder_projector
self.encoder_outputs_downsample_rate = 4
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None): def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
num_speechs, speech_len, embed_dim = speech_features.shape num_speechs, speech_len, embed_dim = speech_features.shape
@ -134,33 +133,39 @@ class SPEECH_LLM(nn.Module):
labels: torch.LongTensor = None, labels: torch.LongTensor = None,
): ):
encoder_outs = self.encoder(fbank) encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
# encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs) speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids) inputs_embeds = self.llm.get_input_embeddings()(input_ids)
enable_logging = False
rank = get_rank()
# print("input_ids", input_ids, input_ids.shape) # log only on rank 0, training using deep
# print("labels", labels, labels.shape) if enable_logging and rank == 0:
# print("inputs_embeds", inputs_embeds.shape, inputs_embeds) print("input_ids", input_ids, input_ids.shape)
# print("attention_mask_before", attention_mask.shape, attention_mask) print("labels", labels, labels.shape)
# print(2333333333333333333333333333) print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
print("attention_mask_before", attention_mask.shape, attention_mask)
print(2333333333333333333333333333)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features( inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels speech_features, inputs_embeds, input_ids, attention_mask, labels
) )
# print("labels", labels, labels.shape) if enable_logging and rank == 0:
# print("speech_features", speech_features.shape, speech_features) print("speech_features", speech_features.shape, speech_features)
# print("inputs_embeds after", inputs_embeds.shape, inputs_embeds) print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
# print("attention_mask", attention_mask.shape, attention_mask) print("attention_mask", attention_mask.shape, attention_mask)
# print("position_ids", position_ids.shape, position_ids) print("position_ids", position_ids.shape, position_ids)
# print("================================================================") print("labels", labels, labels.shape)
# input() print("================================================================")
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels) model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
# model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids) # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
with torch.no_grad(): with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1) preds = torch.argmax(model_outputs.logits, -1)
if enable_logging and rank == 0:
print("preds", preds, preds.shape)
print(4555555555555555555555555555555555555555555)
acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID) acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
return model_outputs, acc return model_outputs, acc
@ -173,9 +178,6 @@ class SPEECH_LLM(nn.Module):
): ):
encoder_outs = self.encoder(fbank) encoder_outs = self.encoder(fbank)
# downsample encoder_outs by 4
# encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
speech_features = self.encoder_projector(encoder_outs) speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16) speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids) inputs_embeds = self.llm.get_input_embeddings()(input_ids)
@ -196,10 +198,7 @@ class SPEECH_LLM(nn.Module):
eos_token_id=self.llm.config.eos_token_id, eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id pad_token_id=self.llm.config.pad_token_id
) )
# print(generated_ids, input_ids)
# generated_ids = [
# output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids)
# ]
return generated_ids return generated_ids

View File

@ -106,7 +106,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--encoder-projector-ds-rate", "--encoder-projector-ds-rate",
type=int, type=int,
default=4, default=1,
help="Downsample rate for the encoder projector.", help="Downsample rate for the encoder projector.",
) )
@ -287,7 +287,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 50, "log_interval": 50,
"reset_interval": 200, "reset_interval": 200,
"valid_interval": 10000, "valid_interval": 5000,
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -408,12 +408,17 @@ def compute_loss(
tokenize=True, tokenize=True,
chat_template=TEMPLATE, chat_template=TEMPLATE,
add_generation_prompt=False, add_generation_prompt=False,
padding="max_length", padding="longest", # FIX me change padding to longest
max_length=max_len, max_length=max_len,
truncation=True, truncation=True,
) )
) )
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == 'right':
texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
else:
texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
input_ids = torch.tensor(texts, dtype=torch.int) input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone() target_ids = input_ids.clone()
@ -423,8 +428,7 @@ def compute_loss(
mask_prompt = True mask_prompt = True
if mask_prompt: if mask_prompt:
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant")) mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
# then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644 # then mask all tokens before the first token e.g. 151646 (speech), 151645 <assistant>, 198 \n
# target_ids[mask_indices[0], :mask_indices[1]+3] = IGNORE_TOKEN_ID
for i in range(mask_indices[0].size(0)): for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i] row = mask_indices[0][i]
col = mask_indices[1][i] col = mask_indices[1][i]
@ -526,7 +530,7 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
info["acc"] = acc info["acc"] = acc * info["frames"] # WAR: to avoid normalization by the number of frames
return loss, info return loss, info
@ -743,22 +747,24 @@ def run(rank, world_size, args):
speech_encoder = whisper_model.encoder speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn: if params.use_flash_attn:
attn_implementation = "flash_attention_2" attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16 # torch_dtype=torch.bfloat16
torch_dtype=torch.float16 torch_dtype=torch.float16
tokenizer.padding_side = 'left'
else: else:
attn_implementation = "eager" attn_implementation = "eager"
torch_dtype=torch.float16 torch_dtype=torch.float16
tokenizer.padding_side = 'right'
llm = AutoModelForCausalLM.from_pretrained( llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name, params.llm_path_or_name,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
) )
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
# tokenizer.padding_side = 'left'
special_tokens_dict = { special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN] "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
} }
@ -766,7 +772,7 @@ def run(rank, world_size, args):
llm.config.pad_token_id = tokenizer.pad_token_id llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size) encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
model = SPEECH_LLM( model = SPEECH_LLM(
speech_encoder, speech_encoder,
@ -774,6 +780,10 @@ def run(rank, world_size, args):
encoder_projector, encoder_projector,
) )
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")