Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Commit 796663066f: mask unrelated labels
Parent: 3ac27d5ad4
@@ -3,23 +3,23 @@ export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall_llm
 # pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
 # pip install -r whisper/requirements.txt
 export CUDA_VISIBLE_DEVICES=0,1
-# torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
-# --max-duration 80 \
-# --exp-dir ./whisper_llm_zh/exp_test \
-# --speech-encoder-path-or-name tiny \
-# --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
-# --manifest-dir data/fbank \
-# --deepspeed \
-# --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
-# --use-flash-attn False
+torchrun --nproc_per_node 2 ./whisper_llm_zh/train.py \
+  --max-duration 1 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name tiny \
+  --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
+  --manifest-dir data/fbank \
+  --deepspeed \
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn False
 
 
 
-python3 ./whisper_llm_zh/decode.py \
-  --max-duration 80 \
-  --exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
-  --speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-  --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
-  --epoch 1 --avg 1 \
-  --manifest-dir data/fbank \
-  --use-flash-attn False
+# python3 ./whisper_llm_zh/decode.py \
+#   --max-duration 80 \
+#   --exp-dir ./whisper_llm_zh/exp_qwen_0.5B \
+#   --speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+#   --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \
+#   --epoch 1 --avg 1 \
+#   --manifest-dir data/fbank \
+#   --use-flash-attn False
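Note: the training run is switched from a commented-out example to a live command, and `--max-duration` drops from 80 to 1, which makes this a smoke-test configuration. In icefall recipes this flag bounds the total seconds of audio per batch via lhotse's dynamic samplers. Below is a minimal sketch of that mechanism; the manifest path and values are illustrative assumptions, not taken from this commit.

# Sketch only: how a max-duration-style cap is applied with lhotse's
# DynamicBucketingSampler. The manifest path below is hypothetical.
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # assumed manifest name
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=1.0,  # total seconds of audio per batch; 1 yields near single-cut batches
    shuffle=True,
)
for batch_cuts in sampler:
    print(len(batch_cuts))  # tiny batches, useful for smoke-testing the pipeline
    break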
@@ -127,11 +127,15 @@ class SPEECH_LLM(nn.Module):
         encoder_outs = encoder_outs[:, ::self.encoder_outputs_downsample_rate]
 
         speech_features = self.encoder_projector(encoder_outs)
 
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+        #print("input_ids", input_ids, input_ids.shape)
+        #print("labels", labels, labels.shape)
         inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
+        #print("labels", labels, labels.shape)
+        #print("speech_features", speech_features.shape)
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
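For context on the debug prints added around `_merge_input_ids_with_speech_features`: that helper splices the projected speech frames into the token-embedding sequence where the speech placeholder sits, which is why `labels`, `attention_mask`, and `position_ids` all change shape across the call. A single-sequence sketch of the splice, with made-up shapes; this is not the repo's actual implementation.

import torch

def splice_speech(inputs_embeds, speech_features, speech_pos):
    # Drop the placeholder embedding at speech_pos and put the projected
    # speech frames in its place, lengthening the sequence.
    return torch.cat(
        [inputs_embeds[:speech_pos], speech_features, inputs_embeds[speech_pos + 1:]],
        dim=0,
    )

text_embeds = torch.randn(10, 896)   # 10 text tokens, hidden size 896 (assumed)
speech_feats = torch.randn(50, 896)  # 50 downsampled, projected encoder frames
merged = splice_speech(text_embeds, speech_feats, speech_pos=3)
print(merged.shape)  # torch.Size([59, 896]): 10 - 1 placeholder + 50 frames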
@@ -416,6 +416,12 @@ def compute_loss(
     # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     target_ids = input_ids.clone()
     target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
+    # first get the indices of the tokens
+    mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN))
+    # then mask all tokens before the first token e.g. 151646 (speech), 151645, 198, 151644
+    target_ids[mask_indices[0], :mask_indices[1]+4] = IGNORE_TOKEN_ID
+
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
 
     return input_ids, attention_mask, target_ids
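This hunk is the commit's point: besides padding, the labels for the prompt prefix, that is everything up to and including the speech placeholder and the three chat-template tokens after it (hence the `+4`), are set to IGNORE_TOKEN_ID so the loss covers only the response tokens. A toy worked example of that intent, using a dummy speech-token id in place of 151646 and a per-row loop to keep each slice end a plain int:

# Toy demonstration of masking every label up to and including the
# speech placeholder plus the next 3 template tokens (the "+4" above).
# Token id 5 stands in for the real DEFAULT_SPEECH_TOKEN id (151646).
import torch

IGNORE_TOKEN_ID = -100
input_ids = torch.tensor([[1, 2, 5, 7, 8, 9, 3, 4],
                          [5, 7, 8, 9, 2, 3, 4, 1]])
target_ids = input_ids.clone()

rows, cols = torch.where(input_ids == 5)
for r, c in zip(rows.tolist(), cols.tolist()):
    target_ids[r, : c + 4] = IGNORE_TOKEN_ID  # mask prefix through placeholder + 3

print(target_ids)
# tensor([[-100, -100, -100, -100, -100, -100,    3,    4],
#         [-100, -100, -100, -100,    2,    3,    4,    1]])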