diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index b5783d5dd..745cf5104 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -497,7 +497,7 @@ def main():
         torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    tokenizer.padding_side = 'left'
+    # tokenizer.padding_side = 'left'
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index b306b7fe5..7175fab17 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -140,18 +140,25 @@ class SPEECH_LLM(nn.Module):
         speech_features = self.encoder_projector(encoder_outs)
 
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
 
+        # print("input_ids", input_ids, input_ids.shape)
         # print("labels", labels, labels.shape)
         # print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
+        # print("attention_mask_before", attention_mask.shape, attention_mask)
+        # print(2333333333333333333333333333)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
         # print("labels", labels, labels.shape)
         # print("speech_features", speech_features.shape, speech_features)
         # print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
+        # print("attention_mask", attention_mask.shape, attention_mask)
+        # print("position_ids", position_ids.shape, position_ids)
+        # print("================================================================")
         # input()
-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
 
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
             acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 43bab3491..1f6d4abad 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -758,7 +758,7 @@ def run(rank, world_size, args):
         torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-    tokenizer.padding_side = 'left'
+    # tokenizer.padding_side = 'left'
    special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
@@ -820,8 +820,8 @@ def run(rank, world_size, args):
         return True
 
     # train_cuts = multi_dataset.train_cuts()
-    # train_cuts = multi_dataset.aishell_train_cuts()
-    train_cuts = multi_dataset.aishell2_train_cuts()
+    train_cuts = multi_dataset.aishell_train_cuts()
+    # train_cuts = multi_dataset.aishell2_train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: