mask unrelated labels

Authored by root on 2024-06-06 08:57:26 +00:00; committed by Yuekai Zhang
parent 3ac27d5ad4
commit 796663066f
3 changed files with 29 additions and 19 deletions


@@ -3,23 +3,23 @@ export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
 # pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
 # pip install -r whisper/requirements.txt
 export CUDA_VISIBLE_DEVICES=0,1
-# torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
-# --max-duration 80 \
-# --exp-dir ./whisper_llm_zh/exp_test \
-# --speech-encoder-path-or-name tiny \
-# --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
-# --manifest-dir data/fbank \
-# --deepspeed \
-# --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
-# --use-flash-attn False
-python3 ./whisper_llm_zh/decode.py \
-  --max-duration 80 \
-  --exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
-  --speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
+  --max-duration 1 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name tiny \
   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
-  --epoch 1 --avg 1 \
   --manifest-dir data/fbank \
-  --use-flash-attn False
+  --deepspeed \
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn False
+# python3 ./whisper_llm_zh/decode.py \
+# --max-duration 80 \
+# --exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
+# --speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+# --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
+# --epoch 1 --avg 1 \
+# --manifest-dir data/fbank \
+# --use-flash-attn False
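
For context on the two DeepSpeed flags the training command now passes: train.py itself is not part of this diff, but scripts launched this way typically wire those flags up through deepspeed.initialize. A minimal sketch under that assumption (the model below is a stand-in, not the recipe's actual SPEECH_LLM):

# Sketch only: assumes train.py follows the standard DeepSpeed argument pattern.
import argparse

import deepspeed
import torch.nn as nn

parser = argparse.ArgumentParser()
# Adds --deepspeed and --deepspeed_config (path to e.g. ds_config_zero1.json).
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()

model = nn.Linear(80, 10)  # stand-in for the speech-LLM model
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters()
)
# The torchrun-spawned ranks then drive model_engine.backward(loss) and
# model_engine.step() instead of a plain PyTorch optimizer.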


@@ -127,11 +127,15 @@ class SPEECH_LLM(nn.Module):
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
         speech_features = self.encoder_projector(encoder_outs)
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+        # print("input_ids", input_ids, input_ids.shape)
+        # print("labels", labels, labels.shape)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
+        # print("labels", labels, labels.shape)
+        # print("speech_features", speech_features.shape)
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
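
_merge_input_ids_with_speech_features itself is outside this hunk. As a rough, single-example sketch of the LLaVA-style merge it appears to perform (the placeholder id, shapes, and function name below are assumptions, not the recipe's actual code; the real method also returns the updated attention mask and position ids):

import torch

IGNORE_TOKEN_ID = -100
SPEECH_TOKEN_ID = 151646  # assumed id of the speech placeholder token

def merge_speech_into_text(speech_features, inputs_embeds, input_ids, labels):
    # speech_features: (T_speech, D); inputs_embeds: (T_text, D);
    # input_ids / labels: (T_text,). Batching and padding are omitted here.
    pos = (input_ids == SPEECH_TOKEN_ID).nonzero(as_tuple=True)[0].item()
    # Splice the projected speech frames in where the placeholder token sat.
    merged_embeds = torch.cat(
        [inputs_embeds[:pos], speech_features, inputs_embeds[pos + 1:]], dim=0
    )
    # Speech frames carry no text targets, so they are ignored by the LM loss.
    speech_labels = torch.full((speech_features.size(0),), IGNORE_TOKEN_ID)
    merged_labels = torch.cat([labels[:pos], speech_labels, labels[pos + 1:]], dim=0)
    return merged_embeds, merged_labels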


@@ -416,6 +416,12 @@ def compute_loss(
     # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     target_ids = input_ids.clone()
     target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    # Mask the prompt with IGNORE_TOKEN_ID so the loss covers only the answer:
+    # everything up to and including DEFAULT_SPEECH_TOKEN (id 151646) and the
+    # three chat-template tokens that follow it (151645, 198, 151644).
+    mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
+    for row, col in zip(mask_indices[0], mask_indices[1]):
+        target_ids[row, : col + 4] = IGNORE_TOKEN_ID
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids
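
The masking above can be sanity-checked in isolation. In this toy batch every id except the four special ones named in the comment is made up, and IGNORE_TOKEN_ID follows the usual -100 convention:

import torch

IGNORE_TOKEN_ID = -100  # standard "ignore" label for cross-entropy losses
SPEECH_TOKEN_ID = 151646  # id of DEFAULT_SPEECH_TOKEN in this recipe

# One made-up row: prompt tokens, speech placeholder, template tokens, answer.
input_ids = torch.tensor([[151644, 872, 198, 151646, 151645, 198, 151644, 77, 88, 99]])
target_ids = input_ids.clone()

rows, cols = torch.where(input_ids == SPEECH_TOKEN_ID)
for row, col in zip(rows, cols):
    target_ids[row, : col + 4] = IGNORE_TOKEN_ID

print(target_ids)
# tensor([[-100, -100, -100, -100, -100, -100, -100, 77, 88, 99]])
# Only the answer ids (77, 88, 99) keep their labels and contribute to the loss.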