remove position ids

Author: root, 2024-06-07 09:39:28 +00:00 (committed by Yuekai Zhang)
Commit: eb2c255e1e
Parent: 639feab4df
3 changed files with 12 additions and 5 deletions

@@ -497,7 +497,7 @@ def main():
         torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    tokenizer.padding_side = 'left'
+    # tokenizer.padding_side = 'left'
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
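
For context, a minimal sketch of what tokenizer.padding_side controls, assuming a Hugging Face tokenizer; the checkpoint name and texts below are illustrative and not taken from the recipe:

from transformers import AutoTokenizer

# Illustrative checkpoint; the recipe loads params.llm_path_or_name instead.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

texts = ["hi", "a noticeably longer example sentence"]

tokenizer.padding_side = "right"  # pad tokens appended at the end
right_batch = tokenizer(texts, padding=True, return_tensors="pt")

tokenizer.padding_side = "left"   # pad tokens prepended, common for batched decoder-only generation
left_batch = tokenizer(texts, padding=True, return_tensors="pt")

# Same token content either way; only the placement of padding
# (the zeros in attention_mask) differs between the two batches.
print(right_batch["attention_mask"])
print(left_batch["attention_mask"])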

@@ -140,18 +140,25 @@ class SPEECH_LLM(nn.Module):
         speech_features = self.encoder_projector(encoder_outs)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         # print("input_ids", input_ids, input_ids.shape)
         # print("labels", labels, labels.shape)
         # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
+        # print("attention_mask_before", attention_mask.shape, attention_mask)
+        # print(2333333333333333333333333333)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
         # print("labels", labels, labels.shape)
         # print("speech_features", speech_features.shape, speech_features)
         # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
+        # print("attention_mask", attention_mask.shape, attention_mask)
+        # print("position_ids", position_ids.shape, position_ids)
+        # print("================================================================")
         # input()
-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
             acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
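
For context, the accuracy check at the end of this hunk compares the prediction at position t against the label at t+1 and skips masked positions. Below is a self-contained sketch of that computation; masked_token_accuracy is an illustrative stand-in for the recipe's compute_accuracy, and IGNORE_TOKEN_ID is assumed to be the usual -100:

import torch

IGNORE_TOKEN_ID = -100  # assumption: the usual ignore index for label masking


def masked_token_accuracy(preds: torch.Tensor, labels: torch.Tensor, ignore_label: int) -> float:
    """Fraction of non-ignored positions where preds equal labels."""
    mask = labels != ignore_label
    correct = ((preds == labels) & mask).sum()
    total = mask.sum().clamp(min=1)
    return (correct.float() / total.float()).item()


# logits: (batch, seq_len, vocab); labels: (batch, seq_len)
logits = torch.randn(2, 6, 10)
labels = torch.randint(0, 10, (2, 6))
labels[:, :2] = IGNORE_TOKEN_ID  # e.g. prompt/speech positions excluded from the loss

preds = torch.argmax(logits, dim=-1)
# Token t predicts token t+1, hence the one-position shift on both tensors.
print(masked_token_accuracy(preds[:, :-1], labels[:, 1:], IGNORE_TOKEN_ID))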

@@ -758,7 +758,7 @@ def run(rank, world_size, args):
         torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    tokenizer.padding_side = 'left'
+    # tokenizer.padding_side = 'left'
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
@@ -820,8 +820,8 @@ def run(rank, world_size, args):
         return True

     # train_cuts = multi_dataset.train_cuts()
-    # train_cuts = multi_dataset.aishell_train_cuts()
-    train_cuts = multi_dataset.aishell2_train_cuts()
+    train_cuts = multi_dataset.aishell_train_cuts()
+    # train_cuts = multi_dataset.aishell2_train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
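
The remove_short_and_long_utt predicate referenced above is not part of this diff; below is a hedged sketch of what such a Lhotse cut filter commonly looks like, with assumed duration bounds:

def remove_short_and_long_utt(c) -> bool:
    # Keep a cut only when its duration falls within a sensible range.
    # The 1.0 s and 30.0 s bounds are illustrative, not the recipe's actual limits.
    return 1.0 <= c.duration <= 30.0

# Usage with a Lhotse CutSet, mirroring the changed line above:
# train_cuts = train_cuts.filter(remove_short_and_long_utt)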