mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

remove position ids

This commit is contained in:
parent 639feab4df
commit eb2c255e1e
@@ -497,7 +497,7 @@ def main():
         torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    tokenizer.padding_side = 'left'
+    # tokenizer.padding_side = 'left'
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
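Note on the change above: for decoder-only LLMs, left padding matters mainly for batched generation (it keeps the last position of every row a real token), while right padding is the usual choice when training with labels, since padded positions are masked out of the loss anyway. A minimal sketch of the difference, using a generic Hugging Face tokenizer as a stand-in (the "gpt2" name is illustrative, not from this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

prompts = ["short prompt", "a somewhat longer prompt"]

# Right padding (the HF default): pad ids go at the end of each row.
tokenizer.padding_side = "right"
right = tokenizer(prompts, padding=True, return_tensors="pt")

# Left padding: pad ids go at the front, so generation can continue
# from the last (real) token of every row in the batch.
tokenizer.padding_side = "left"
left = tokenizer(prompts, padding=True, return_tensors="pt")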
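The special_tokens_dict registers DEFAULT_SPEECH_TOKEN as a placeholder that can later be swapped for encoder output. A hedged sketch of the registration step (the token string shown is an assumption; the real value is defined elsewhere in the recipe):

DEFAULT_SPEECH_TOKEN = "<speech>"  # assumed value, for illustration only

special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
num_added = tokenizer.add_special_tokens(special_tokens_dict)

# Any newly added ids need embedding rows in the LLM:
# if num_added > 0:
#     model.resize_token_embeddings(len(tokenizer))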
@@ -140,18 +140,25 @@ class SPEECH_LLM(nn.Module):
         speech_features = self.encoder_projector(encoder_outs)

         inputs_embeds = self.llm.get_input_embeddings()(input_ids)

         # print("input_ids", input_ids, input_ids.shape)
         # print("labels", labels, labels.shape)
         # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
         # print("attention_mask_before", attention_mask.shape, attention_mask)
         # print(2333333333333333333333333333)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
         # print("labels", labels, labels.shape)
         # print("speech_features", speech_features.shape, speech_features)
         # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
         # print("attention_mask", attention_mask.shape, attention_mask)
         # print("position_ids", position_ids.shape, position_ids)
         # print("================================================================")
         # input()

-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
             acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
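Dropping position_ids from the self.llm(...) call is generally safe for Hugging Face decoder-only models: when the argument is omitted, they derive positions internally, typically as a cumulative sum over the attention mask. A sketch of that common fallback (this mirrors the usual HF pattern, not necessarily the exact model used here):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum over real tokens yields 0-based positions per row;
# masked slots are filled with a harmless valid index.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])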
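The accuracy check shifts predictions and labels by one position because a causal LM's logits at step t predict the token at step t+1, hence preds[:, :-1] against labels[:, 1:]. compute_accuracy itself is defined elsewhere in icefall; a hypothetical stand-in, for orientation only:

import torch

def compute_accuracy(preds: torch.Tensor, labels: torch.Tensor, ignore_label: int) -> float:
    # Only score positions with real supervision; ignore_label
    # (IGNORE_TOKEN_ID) marks prompt/padding positions excluded from the loss.
    mask = labels != ignore_label
    correct = (preds == labels) & mask
    return correct.sum().item() / max(mask.sum().item(), 1)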
@@ -758,7 +758,7 @@ def run(rank, world_size, args):
         torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    tokenizer.padding_side = 'left'
+    # tokenizer.padding_side = 'left'
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
@@ -820,8 +820,8 @@ def run(rank, world_size, args):
         return True

     # train_cuts = multi_dataset.train_cuts()
-    # train_cuts = multi_dataset.aishell_train_cuts()
-    train_cuts = multi_dataset.aishell2_train_cuts()
+    train_cuts = multi_dataset.aishell_train_cuts()
+    # train_cuts = multi_dataset.aishell2_train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
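train_cuts.filter(remove_short_and_long_utt) uses lhotse's lazy CutSet.filter with a per-cut predicate. A minimal sketch of such a predicate (the duration bounds are illustrative assumptions, not the recipe's actual values):

def remove_short_and_long_utt(c) -> bool:
    # Drop extremes: very short cuts carry little signal, and very long
    # ones blow up memory in a batch; the bounds below are assumed.
    return 1.0 <= c.duration <= 20.0

# train_cuts = train_cuts.filter(remove_short_and_long_utt)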