diff --git a/egs/speech_llm/ASR_LLM/debug.sh b/egs/speech_llm/ASR_LLM/debug.sh
index 255de44cd..644167275 100755
--- a/egs/speech_llm/ASR_LLM/debug.sh
+++ b/egs/speech_llm/ASR_LLM/debug.sh
@@ -3,23 +3,23 @@ export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
 # pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
 # pip install -r whisper/requirements.txt
 export CUDA_VISIBLE_DEVICES=0,1
-# torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
-#   --max-duration 80 \
-#   --exp-dir ./whisper_llm_zh/exp_test \
-#   --speech-encoder-path-or-name tiny \
-#   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
-#   --manifest-dir data/fbank \
-#   --deepspeed \
-#   --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
-#   --use-flash-attn False
-
-
-
-python3 ./whisper_llm_zh/decode.py \
-  --max-duration 80 \
-  --exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
-  --speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
+  --max-duration 1 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name tiny \
   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
-  --epoch 1 --avg 1 \
   --manifest-dir data/fbank \
-  --use-flash-attn False
\ No newline at end of file
+  --deepspeed \
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn False
+
+
+
+# python3 ./whisper_llm_zh/decode.py \
+#   --max-duration 80 \
+#   --exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
+#   --speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+#   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
+#   --epoch 1 --avg 1 \
+#   --manifest-dir data/fbank \
+#   --use-flash-attn False
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index 5d4585b11..eec9d7812 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -127,11 +127,15 @@ class SPEECH_LLM(nn.Module):
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
         speech_features = self.encoder_projector(encoder_outs)
-
+
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+        #print("input_ids", input_ids, input_ids.shape)
+        #print("labels", labels, labels.shape)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
+        #print("labels", labels, labels.shape)
+        #print("speech_features", speech_features.shape)
 
         model_outputs = self.llm(
             inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids
         )
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 3315a5b53..dab19ca8b 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -416,6 +416,12 @@ def compute_loss(
         # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
         target_ids = input_ids.clone()
         target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+        # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
+        # first get the indices of the speech tokens
+        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
+        # then mask everything up to and including that token plus the next 3 tokens, e.g. 151646 (speech), 151645, 198, 151644
+        target_ids[mask_indices[0], :mask_indices[1]+4] = IGNORE_TOKEN_ID
+
         attention_mask = input_ids.ne(tokenizer.pad_token_id)
 
         return input_ids, attention_mask, target_ids
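Side note on the new masking logic in train.py: the short sketch below is not part of the diff, it only illustrates the same idea with explicit per-row indexing. The speech-token id 151646 is taken from the comment above; IGNORE_TOKEN_ID is assumed to be -100 (the usual Hugging Face label-smoother ignore index) and PAD_TOKEN_ID is a placeholder for illustration. The slice form used in the diff, target_ids[mask_indices[0], :mask_indices[1]+4], appears to rely on mask_indices[1] collapsing to a single index (e.g. batch size 1); looping over the rows returned by torch.where handles larger batches as well.

import torch

IGNORE_TOKEN_ID = -100     # assumed: standard HF ignore index for the LM loss
SPEECH_TOKEN_ID = 151646   # from the comment in the diff (DEFAULT_SPEECH_TOKEN)
PAD_TOKEN_ID = 151643      # placeholder pad id, for illustration only

def build_target_ids(input_ids: torch.LongTensor) -> torch.LongTensor:
    # Ignore padding positions in the loss.
    target_ids = input_ids.clone()
    target_ids[target_ids == PAD_TOKEN_ID] = IGNORE_TOKEN_ID
    # For every row, ignore everything up to and including the speech token
    # plus the three tokens that follow it.
    rows, cols = torch.where(input_ids == SPEECH_TOKEN_ID)
    for row, col in zip(rows.tolist(), cols.tolist()):
        target_ids[row, : col + 4] = IGNORE_TOKEN_ID
    return target_ids

# Tiny fake batch with the speech token at a different position in each row.
batch = torch.tensor(
    [
        [151644, 151646, 151645, 198, 151644, 100, 101, 151645],
        [151646, 151645, 198, 151644, 200, 201, 151645, 151643],
    ]
)
print(build_target_ids(batch))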