diff --git a/egs/speech_llm/ASR_LLM/README.md b/egs/speech_llm/ASR_LLM/README.md index 1e60c733c..171240db0 100644 --- a/egs/speech_llm/ASR_LLM/README.md +++ b/egs/speech_llm/ASR_LLM/README.md @@ -1,39 +1,20 @@ # Introduction -This recipe includes scripts for training Zipformer model using multiple Chinese datasets. +This recipe includes scripts for training a [Qwen-Audio](https://github.com/QwenLM/Qwen-Audio/tree/main)-style model using multiple datasets. -# Included Training Sets -1. THCHS-30 -2. AiShell-{1,2,4} -3. ST-CMDS -4. Primewords -5. MagicData -6. Aidatatang_200zh -7. AliMeeting -8. WeNetSpeech -9. KeSpeech-ASR +
+<div align="center">
+ <img src="./assets/framework.png"/>
+</div>
+
-|Datset| Number of hours| URL| -|---|---:|---| -|**TOTAL**|14,106|---| -|THCHS-30|35|https://www.openslr.org/18/| -|AiShell-1|170|https://www.openslr.org/33/| -|AiShell-2|1,000|http://www.aishelltech.com/aishell_2| -|AiShell-4|120|https://www.openslr.org/111/| -|ST-CMDS|110|https://www.openslr.org/38/| -|Primewords|99|https://www.openslr.org/47/| -|aidatatang_200zh|200|https://www.openslr.org/62/| -|MagicData|755|https://www.openslr.org/68/| -|AliMeeting|100|https://openslr.org/119/| -|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech| -|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech| +[./RESULTS.md](./RESULTS.md) contains the latest results. +# ASR_LLM -# Included Test Sets -1. Aishell-{1,2,4} -2. Aidatatang_200zh -3. AliMeeting -4. MagicData -5. KeSpeech-ASR -6. WeNetSpeech +The following table lists the folders for different tasks. + +| | Speech Encoder | LLM | Comment | +|---------------------------------------|---------------------|--------------------|---------------------------------------------------| +| [whisper_llm_zh](./whisper_llm_zh) | Whisper | Qwen2 | [Using multiple Chinese datasets](https://github.com/k2-fsa/icefall/tree/master/egs/multi_zh-hans/ASR) | diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md index a7f3bc4f7..773824c05 100644 --- a/egs/speech_llm/ASR_LLM/RESULTS.md +++ b/egs/speech_llm/ASR_LLM/RESULTS.md @@ -1,116 +1,62 @@ ## Results -### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2 -#### Whisper -[./whisper](./whisper) +### whisper_llm_zh finetuning results -Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search. - -| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | -|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| -| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | -| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | +| Training Dataset | Speech Encoder | LLM | Projector |Comment | CER | +| -------------------------| ----------------|------|--------------------------------------------------|-----|--| +| Aishell1 | whisper-large-v2-aishell1-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.76% | + Command for training is: ```bash -pip install -r whisper/requirements.txt +pip install -r whisper_llm_zh/requirements.txt -# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54 +pip install huggingface_hub['cli'] +mkdir -p models/whisper models/qwen -torchrun --nproc-per-node 8 ./whisper/train.py \ +# For aishell fine-tuned whisper model +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +# For multi-hans fine-tuned whisper model +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper 
v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt + +# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct + +torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --max-duration 200 \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ + --exp-dir ./whisper_llm_zh/exp_test \ + --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ + --manifest-dir data/fbank \ --deepspeed \ - --deepspeed_config ./whisper/ds_config_zero1.json + --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ + --use-flash-attn True \ + --use-lora True --unfreeze-llm True ``` Command for decoding using fine-tuned models: ```bash -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper -ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt +mkdir -p models/whisper models/qwen models/checkpoint +huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ +# For aishell fine-tuned whisper model +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +# For multi-hans fine-tuned whisper model +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt + +huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct + +mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B +ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt + +python3 ./whisper_llm_zh/decode.py \ + --max-duration 80 \ + --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \ + --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --llm-path-or-name models/qwen \ --epoch 999 --avg 1 \ - --beam-size 10 --max-duration 50 + --manifest-dir data/fbank \ + --use-flash-attn True \ + --use-lora True --dataset aishell ``` - -Fine-tuned models, training logs, decoding logs, tensorboard and decoding results -are available at - - - -### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model - -This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall. - -#### Non-streaming (with CTC head) - -Best results (num of params : ~69M): - -The training command: - -``` -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 20 \ - --use-fp16 1 \ - --max-duration 600 \ - --num-workers 8 \ - --use-ctc 1 -``` - -The decoding command: - -``` -./zipformer/decode.py \ - --epoch 20 \ - --avg 1 \ - --use-ctc 1 -``` - -Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using BPE model ( # tokens is 2000, byte fallback enabled). 
- -| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | -|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 | -| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | - -Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ - -#### Non-streaming - -Best results (num of params : ~69M): - -The training command: - -``` -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 20 \ - --use-fp16 1 \ - --max-duration 600 \ - --num-workers 8 -``` - -The decoding command: - -``` -./zipformer/decode.py \ - --epoch 20 \ - --avg 1 -``` - -Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled). - -| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | -|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | - - -Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ diff --git a/egs/speech_llm/ASR_LLM/assets/framework.png b/egs/speech_llm/ASR_LLM/assets/framework.png new file mode 100644 index 000000000..dc48bda78 Binary files /dev/null and b/egs/speech_llm/ASR_LLM/assets/framework.png differ diff --git a/egs/speech_llm/ASR_LLM/prepare.sh b/egs/speech_llm/ASR_LLM/prepare.sh new file mode 100644 index 000000000..6f5ed5448 --- /dev/null +++ b/egs/speech_llm/ASR_LLM/prepare.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=0 +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "stage 0: Download whisper-large-v2 aishell 1 fbank feature from huggingface" + + # pip install huggingface_hub['cli'] + # for aishell 1 + huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse + +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface" + + # for multi-hans-zh + huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse + huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse + huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "stage 2: Download whisper-large-v2 speechio test sets fbank feature from huggingface" + + # for speechio test sets + mkdir data_speechio + huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio + mv data_speechio/fbank/* data/fbank +fi diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index f3bdb452c..882ce4fbf 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -20,21 +20,34 @@ """ Usage: # Command for decoding using fine-tuned models: -git lfs install -git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper -ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt -python3 ./whisper/decode.py \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ +pip install huggingface_hub['cli'] +mkdir -p models/whisper models/qwen models/checkpoint +huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B + +# For aishell fine-tuned whisper model +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +# For multi-hans fine-tuned whisper model +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt + +huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct + +mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B +ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt + +python3 ./whisper_llm_zh/decode.py \ + --max-duration 80 \ + --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \ + --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --llm-path-or-name models/qwen \ --epoch 999 --avg 1 \ - --beam-size 10 --max-duration 50 - + --manifest-dir data/fbank \ + --use-flash-attn True \ + --use-lora True --dataset aishell """ import argparse import logging -import re from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -42,18 +55,17 @@ from typing import Dict, List, Optional, Tuple import k2 import torch import torch.nn as nn +import transformers import whisper from asr_datamodule import AsrDataModule from lhotse.cut import Cut +from model import SPEECH_LLM, EncoderProjector from multi_dataset import MultiDataset -#from 
tn.chinese.normalizer import Normalizer -#from whisper.normalizers import BasicTextNormalizer -#from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -#from zhconv import convert -import transformers +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from train import DEFAULT_SPEECH_TOKEN from transformers import AutoModelForCausalLM, AutoTokenizer -from model import EncoderProjector, SPEECH_LLM +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward + from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.env import get_env_info from icefall.utils import ( @@ -63,8 +75,7 @@ from icefall.utils import ( str2bool, write_error_stats, ) -from train import DEFAULT_SPEECH_TOKEN -from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + def average_checkpoints( filenames: List[Path], device: torch.device = torch.device("cpu") @@ -117,6 +128,7 @@ def average_checkpoints( return avg + def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--llm-path-or-name", @@ -135,7 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-projector-ds-rate", type=int, - default=1, + default=8, help="Downsample rate for the encoder projector.", ) @@ -149,10 +161,11 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--use-lora", type=str2bool, - default=False, - help="Whether to use lora to fine-tune llm.", + default=True, + help="Whether to use lora fine-tuned llm checkpoint.", ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -247,6 +260,7 @@ def decode_one_batch( Returns: Return a dict, whose key may be "beam-search". 
""" + def preprocess( messages, tokenizer: transformers.PreTrainedTokenizer, @@ -268,10 +282,16 @@ def decode_one_batch( ) ) max_len_texts = max([len(text) for text in texts]) - if tokenizer.padding_side == 'right': - texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts] + if tokenizer.padding_side == "right": + texts = [ + text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + for text in texts + ] else: - texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts] + texts = [ + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text + for text in texts + ] input_ids = torch.tensor(texts, dtype=torch.int) @@ -302,16 +322,18 @@ def decode_one_batch( feature_len = supervisions["num_frames"] feature_len = feature_len.to(device, dtype=dtype) - messages = [[ - {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, - {"role": "assistant", "content": ""}, - ]] * len(feature) + messages = [ + [ + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, + {"role": "assistant", "content": ""}, + ] + ] * len(feature) - input_ids, attention_mask = preprocess( - messages, tokenizer, max_len=128 + input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128) + + generated_ids = model.decode( + feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device) ) - - generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) return {"beam-search": hyps} @@ -497,14 +519,14 @@ def main(): if params.use_flash_attn: attn_implementation = "flash_attention_2" - # torch_dtype=torch.bfloat16 - torch_dtype=torch.float16 - tokenizer.padding_side = 'left' + # torch_dtype=torch.bfloat16 FIX ME + torch_dtype = torch.float16 + tokenizer.padding_side = "left" else: attn_implementation = "eager" - torch_dtype=torch.float16 - tokenizer.padding_side = 'right' + torch_dtype = torch.float16 + tokenizer.padding_side = "right" llm = AutoModelForCausalLM.from_pretrained( params.llm_path_or_name, @@ -515,23 +537,33 @@ def main(): lora_config = LoraConfig( r=64, lora_alpha=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"], + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "up_proj", + "gate_proj", + "down_proj", + ], task_type="CAUSAL_LM", ) llm = get_peft_model(llm, lora_config) llm.print_trainable_parameters() - special_tokens_dict = { - "additional_special_tokens": [DEFAULT_SPEECH_TOKEN] - } + special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} tokenizer.add_special_tokens(special_tokens_dict) llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>") llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>") llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") - llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) + llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids( + DEFAULT_SPEECH_TOKEN + ) - encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate) + encoder_projector = EncoderProjector( + speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate + ) model = SPEECH_LLM( speech_encoder, @@ -539,7 +571,6 @@ def main(): encoder_projector, ) - if params.avg > 1: start = params.epoch - params.avg + 1 assert start >= 1, start @@ 
-579,7 +610,7 @@ def main(): # if c.duration > 30.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False return True diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py deleted file mode 120000 index e9d239fff..000000000 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py index 440724db2..829ef4e2d 100644 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py @@ -1,11 +1,20 @@ -from torch import nn import torch +from torch import nn from transformers.trainer_pt_utils import LabelSmoother -from icefall.dist import get_rank + IGNORE_TOKEN_ID = LabelSmoother.ignore_index + class EncoderProjector(nn.Module): -# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py + """ + The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model. + Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py. + Args: + encoder_dim (:obj:`int`): The dimension of the encoder outputs. + llm_dim (:obj:`int`): The dimension of the language model. + downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use. + """ + def __init__(self, encoder_dim, llm_dim, downsample_rate=5): super().__init__() self.downsample_rate = downsample_rate @@ -20,16 +29,30 @@ class EncoderProjector(nn.Module): if num_frames_to_discard > 0: x = x[:, :-num_frames_to_discard, :] seq_len = x.size(1) - + x = x.contiguous() - x = x.view(batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate) + x = x.view( + batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate + ) x = self.linear1(x) x = self.relu(x) x = self.linear2(x) return x + class SPEECH_LLM(nn.Module): + """ + The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector. + The encoder is used to extract speech features from the input speech signal. + The encoder projector is used to project the encoder outputs to the same dimension as the language model. + The language model is used to generate the text from the speech features. + Args: + encoder (:obj:`nn.Module`): The encoder module. + llm (:obj:`nn.Module`): The language model module. + encoder_projector (:obj:`nn.Module`): The encoder projector module. + """ + def __init__( self, encoder: nn.Module, @@ -41,23 +64,46 @@ class SPEECH_LLM(nn.Module): self.llm = llm self.encoder_projector = encoder_projector - def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None): + def _merge_input_ids_with_speech_features( + self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None + ): + """ + Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens + with the speech features and padding the input_ids to the maximum length of the speech features. + Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277. 
+ Args: + speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids. + inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids. + input_ids (:obj:`torch.Tensor`): The input ids to merge. + attention_mask (:obj:`torch.Tensor`): The attention mask to merge. + labels (:obj:`torch.Tensor`, `optional`): The labels to merge. + Returns: + :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids. + """ num_speechs, speech_len, embed_dim = speech_features.shape batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)) + left_padding = not torch.sum( + input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id) + ) # 1. Create a mask to know where special speech tokens are special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1) # Compute the maximum embed dimension - max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length - batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id) + max_embed_dim = ( + num_special_speech_tokens.max() * (speech_len - 1) + ) + sequence_length + batch_indices, non_speech_indices = torch.where( + input_ids != self.llm.config.default_speech_token_id + ) # 2. Compute the positions where text should be written # Calculate new positions for text tokens in merged speech-text sequence. # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens. # `torch.cumsum` computes how each speech token shifts subsequent text token positions. # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1 + new_token_positions = ( + torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1 + ) nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1] if left_padding: new_token_positions += nb_speech_pad[:, None] # offset for left padding @@ -65,14 +111,24 @@ class SPEECH_LLM(nn.Module): # 3. Create the full embedding, already padded to the maximum position final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + batch_size, + max_embed_dim, + embed_dim, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, ) final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + batch_size, + max_embed_dim, + dtype=attention_mask.dtype, + device=inputs_embeds.device, ) if labels is not None: final_labels = torch.full( - (batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device + (batch_size, max_embed_dim), + IGNORE_TOKEN_ID, + dtype=input_ids.dtype, + device=input_ids.device, ) # In case the Vision model or the Language model has been offloaded to CPU, we need to manually # set the corresponding tensors into their correct target device. @@ -86,17 +142,28 @@ class SPEECH_LLM(nn.Module): # 4. Fill the embeddings based on the mask. 
If we have ["hey" "", "how", "are"] # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_speech_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_speech_indices] + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[ + batch_indices, non_speech_indices + ] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[ + batch_indices, non_speech_indices + ] if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_speech_indices] + final_labels[batch_indices, text_to_overwrite] = labels[ + batch_indices, non_speech_indices + ] # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835) speech_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + (batch_size, max_embed_dim), + True, + dtype=torch.bool, + device=inputs_embeds.device, ) speech_to_overwrite[batch_indices, text_to_overwrite] = False - speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[:, None].to(target_device) + speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[ + :, None + ].to(target_device) if speech_to_overwrite.sum() != speech_features.shape[:-1].numel(): raise ValueError( @@ -104,12 +171,18 @@ class SPEECH_LLM(nn.Module): f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation." ) - final_embedding[speech_to_overwrite] = speech_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_embedding[speech_to_overwrite] = ( + speech_features.contiguous().reshape(-1, embed_dim).to(target_device) + ) final_attention_mask |= speech_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_( + (final_attention_mask == 0), 1 + ) # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. 
- batch_indices, pad_indices = torch.where(input_ids == self.llm.config.pad_token_id) + batch_indices, pad_indices = torch.where( + input_ids == self.llm.config.pad_token_id + ) indices_to_mask = new_token_positions[batch_indices, pad_indices] final_embedding[batch_indices, indices_to_mask] = 0 @@ -119,62 +192,59 @@ class SPEECH_LLM(nn.Module): return final_embedding, final_attention_mask, final_labels, position_ids - def forward(self, - fbank: torch.Tensor = None, - input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor = None, - labels: torch.LongTensor = None, - ): + def forward( + self, + fbank: torch.Tensor = None, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor = None, + labels: torch.LongTensor = None, + ): encoder_outs = self.encoder(fbank) speech_features = self.encoder_projector(encoder_outs) - - inputs_embeds = self.llm.get_input_embeddings()(input_ids) - - enable_logging = False - rank = get_rank() - # log only on rank 0, training using deep - if enable_logging and rank == 0: - print("input_ids", input_ids, input_ids.shape) - print("labels", labels, labels.shape) - print("inputs_embeds", inputs_embeds.shape, inputs_embeds) - print("attention_mask_before", attention_mask.shape, attention_mask) - print(2333333333333333333333333333) - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features( + inputs_embeds = self.llm.get_input_embeddings()(input_ids) + + ( + inputs_embeds, + attention_mask, + labels, + _, + ) = self._merge_input_ids_with_speech_features( speech_features, inputs_embeds, input_ids, attention_mask, labels ) - if enable_logging and rank == 0: - print("speech_features", speech_features.shape, speech_features) - print("inputs_embeds after", inputs_embeds.shape, inputs_embeds) - print("attention_mask", attention_mask.shape, attention_mask) - print("position_ids", position_ids.shape, position_ids) - print("labels", labels, labels.shape) - print("================================================================") - model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels) - # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids) + model_outputs = self.llm( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels + ) + with torch.no_grad(): preds = torch.argmax(model_outputs.logits, -1) - if enable_logging and rank == 0: - print("preds", preds, preds.shape) - print(4555555555555555555555555555555555555555555) - acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID) + acc = compute_accuracy( + preds.detach()[:, :-1], + labels.detach()[:, 1:], + ignore_label=IGNORE_TOKEN_ID, + ) return model_outputs, acc - - def decode(self, - fbank: torch.Tensor = None, - input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor = None, - **kwargs - ): + def decode( + self, + fbank: torch.Tensor = None, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor = None, + **kwargs, + ): encoder_outs = self.encoder(fbank) speech_features = self.encoder_projector(encoder_outs) speech_features = speech_features.to(torch.float16) inputs_embeds = self.llm.get_input_embeddings()(input_ids) - inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features( + ( + inputs_embeds, + attention_mask, + _, + position_ids, + ) = self._merge_input_ids_with_speech_features( speech_features, inputs_embeds, 
input_ids, attention_mask ) generated_ids = self.llm.generate( @@ -189,7 +259,7 @@ class SPEECH_LLM(nn.Module): temperature=kwargs.get("temperature", 1.0), bos_token_id=self.llm.config.bos_token_id, eos_token_id=self.llm.config.eos_token_id, - pad_token_id=self.llm.config.pad_token_id + pad_token_id=self.llm.config.pad_token_id, ) return generated_ids @@ -197,7 +267,7 @@ class SPEECH_LLM(nn.Module): def compute_accuracy(pad_outputs, pad_targets, ignore_label): """Calculate accuracy. - + Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py Args: pad_outputs (LongTensor): Prediction tensors (B, Lmax). pad_targets (LongTensor): Target label tensors (B, Lmax). @@ -212,4 +282,4 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label): pad_outputs.masked_select(mask) == pad_targets.masked_select(mask) ) denominator = torch.sum(mask) - return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type \ No newline at end of file + return numerator.float() / denominator.float() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py index abfa41b3f..eae967500 100644 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py @@ -248,8 +248,6 @@ class MultiDataset: def aishell_train_cuts(self) -> CutSet: logging.info("About to get multidataset train cuts") - - # AISHELL-1 logging.info("Loading Aishell-1 in lazy mode") aishell_cuts = load_manifest_lazy( self.fbank_dir / "aishell_cuts_train.jsonl.gz" @@ -257,11 +255,8 @@ class MultiDataset: return aishell_cuts - def aishell_dev_cuts(self) -> CutSet: logging.info("About to get multidataset dev cuts") - - # AISHELL logging.info("Loading Aishell set in lazy mode") aishell_dev_cuts = load_manifest_lazy( self.fbank_dir / "aishell_cuts_dev.jsonl.gz" @@ -271,8 +266,6 @@ class MultiDataset: def aishell_test_cuts(self) -> CutSet: logging.info("About to get multidataset test cuts") - - # AISHELL logging.info("Loading Aishell set in lazy mode") aishell_test_cuts = load_manifest_lazy( self.fbank_dir / "aishell_cuts_test.jsonl.gz" @@ -282,12 +275,8 @@ class MultiDataset: "aishell_test": aishell_test_cuts, } - - # aishell 2 def aishell2_train_cuts(self) -> CutSet: logging.info("About to get multidataset train cuts") - - # AISHELL-2 logging.info("Loading Aishell-2 in lazy mode") aishell_2_cuts = load_manifest_lazy( self.fbank_dir / "aishell2_cuts_train.jsonl.gz" @@ -297,8 +286,6 @@ class MultiDataset: def aishell2_dev_cuts(self) -> CutSet: logging.info("About to get multidataset dev cuts") - - # AISHELL-2 logging.info("Loading Aishell-2 set in lazy mode") aishell2_dev_cuts = load_manifest_lazy( self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" @@ -308,8 +295,6 @@ class MultiDataset: def aishell2_test_cuts(self) -> CutSet: logging.info("About to get multidataset test cuts") - - # AISHELL-2 logging.info("Loading Aishell-2 set in lazy mode") aishell2_test_cuts = load_manifest_lazy( self.fbank_dir / "aishell2_cuts_test.jsonl.gz" @@ -321,8 +306,6 @@ class MultiDataset: def wenetspeech_test_meeting_cuts(self) -> CutSet: logging.info("About to get multidataset test cuts") - - # WeNetSpeech logging.info("Loading WeNetSpeech set in lazy mode") wenetspeech_test_meeting_cuts = load_manifest_lazy( self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" @@ -352,4 +335,4 @@ class MultiDataset: test_cuts = load_manifest_lazy(self.fbank_dir / path) results_dict[partition] = test_cuts - return 
results_dict \ No newline at end of file + return results_dict diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt old mode 100755 new mode 100644 index c5a90cb08..a07c7b157 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt @@ -5,9 +5,6 @@ sentencepiece pypinyin tensorboard librosa -# git+https://github.com/yuekaizhang/whisper.git -# zhconv -# WeTextProcessing deepspeed transformers>=4.37.0 flash-attn diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index f4d30d28a..5f224c984 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -17,14 +17,28 @@ # limitations under the License. """ Usage: +# fine-tuning with whisper and Qwen2 +pip install huggingface_hub['cli'] +mkdir -p models/whisper models/qwen -#fine-tuning with deepspeed zero stage 1 -torchrun --nproc-per-node 8 ./whisper/train.py \ +# For aishell fine-tuned whisper model +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +# For multi-hans fine-tuned whisper model +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt + +# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct + +torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --max-duration 200 \ - --exp-dir whisper/exp_large_v2 \ - --model-name large-v2 \ + --exp-dir ./whisper_llm_zh/exp_test \ + --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ + --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ + --manifest-dir data/fbank \ --deepspeed \ - --deepspeed_config ./whisper/ds_config_zero1.json + --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ + --use-flash-attn True \ + --use-lora True --unfreeze-llm True """ import argparse @@ -39,36 +53,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union import deepspeed import k2 -# import optim import torch import torch.multiprocessing as mp import torch.nn as nn +import transformers import whisper from asr_datamodule import AsrDataModule -from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from label_smoothing import LabelSmoothingLoss from lhotse import CutSet, load_manifest from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed +from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector from multi_dataset import MultiDataset -# from optim import Eden, ScaledAdam +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.functional import pad as pad_tensor -# from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter - +from transformers import AutoModelForCausalLM, AutoTokenizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as 
save_checkpoint_impl -from icefall.checkpoint import update_averaged_model -from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist +from icefall.dist import get_rank, get_world_size from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, @@ -77,20 +84,15 @@ from icefall.utils import ( str2bool, ) -from transformers import AutoModelForCausalLM, AutoTokenizer -import transformers -from transformers.trainer_pt_utils import LabelSmoother - -from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training - -#IGNORE_TOKEN_ID = LabelSmoother.ignore_index DEFAULT_SPEECH_TOKEN = "" + def set_batch_count(model: nn.Module, batch_count: float) -> None: for module in model.modules(): if hasattr(module, "batch_count"): module.batch_count = batch_count + def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--llm-path-or-name", @@ -109,7 +111,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-projector-ds-rate", type=int, - default=1, + default=8, help="Downsample rate for the encoder projector.", ) parser.add_argument( @@ -133,6 +135,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Whether to unfreeze llm during training.", ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -162,15 +165,6 @@ def get_parser(): """, ) - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - parser.add_argument( "--exp-dir", type=str, @@ -198,26 +192,6 @@ def get_parser(): """, ) - parser.add_argument( - "--base-lr", type=float, default=1e-5, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - parser.add_argument( "--seed", type=int, @@ -225,44 +199,6 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
- """, - ) - parser.add_argument( "--use-fp16", type=str2bool, @@ -325,6 +261,7 @@ def get_params() -> AttributeDict: return params + def compute_loss( params: AttributeDict, tokenizer: AutoTokenizer, @@ -372,17 +309,23 @@ def compute_loss( tokenize=True, chat_template=TEMPLATE, add_generation_prompt=False, - padding="longest", # FIX me change padding to longest + padding="longest", # FIX me change padding to longest max_length=max_len, truncation=True, ) ) # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id max_len_texts = max([len(text) for text in texts]) - if tokenizer.padding_side == 'right': - texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts] + if tokenizer.padding_side == "right": + texts = [ + text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + for text in texts + ] else: - texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts] + texts = [ + [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text + for text in texts + ] input_ids = torch.tensor(texts, dtype=torch.int) # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] target_ids = input_ids.clone() @@ -391,13 +334,14 @@ def compute_loss( # first get the indices of the tokens mask_prompt = True if mask_prompt: - mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant")) - # then mask all tokens before the first token e.g. 151646 (speech), 151645 , 198 \n + mask_indices = torch.where( + input_ids == tokenizer.convert_tokens_to_ids("assistant") + ) for i in range(mask_indices[0].size(0)): row = mask_indices[0][i] col = mask_indices[1][i] # + 2 to skip: 'assistant', '\n' - target_ids[row, :col+2] = IGNORE_TOKEN_ID + target_ids[row, : col + 2] = IGNORE_TOKEN_ID attention_mask = input_ids.ne(tokenizer.pad_token_id) @@ -458,20 +402,13 @@ def compute_loss( messages = [] for i, text in enumerate(texts): - # message = [ - # {"role": "system", "content": "你是一个能处理音频的助手。"}, - # {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"}, - # {"role": "assistant", "content": text}, - # ] message = [ - {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, - {"role": "assistant", "content": text}, + {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, + {"role": "assistant", "content": text}, ] messages.append(message) - input_ids, attention_mask, target_ids = preprocess( - messages, tokenizer, max_len=128 - ) + input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128) target_ids = target_ids.type(torch.LongTensor) input_ids = input_ids.type(torch.LongTensor) @@ -494,7 +431,9 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() - info["acc"] = acc * info["frames"] # WAR: to avoid normalization by the number of frames + info["acc"] = ( + acc * info["frames"] + ) # WAR: to avoid normalization by the number of frames return loss, info @@ -607,7 +546,7 @@ def train_one_epoch( save_dir=params.exp_dir, tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", client_state={}, - exclude_frozen_parameters=True + exclude_frozen_parameters=True, ) if rank == 0: @@ -702,29 +641,26 @@ def run(rank, world_size, args): logging.info(params) logging.info("About to create model") - - # if 'whisper' in params.speech_encoder_path_or_name: + replace_whisper_encoder_forward() - # TODO: directly loading from whisper-ft checkpoint - # whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") speech_encoder = whisper_model.encoder speech_encoder_dim = whisper_model.dims.n_audio_state - for name, param in speech_encoder.named_parameters(): + for name, param in speech_encoder.named_parameters(): param.requires_grad = False speech_encoder.eval() tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) if params.use_flash_attn: attn_implementation = "flash_attention_2" - # torch_dtype=torch.bfloat16 - torch_dtype=torch.float16 - tokenizer.padding_side = 'left' + # torch_dtype=torch.bfloat16 FIX ME + torch_dtype = torch.float16 + tokenizer.padding_side = "left" else: attn_implementation = "eager" - torch_dtype=torch.float16 - tokenizer.padding_side = 'right' + torch_dtype = torch.float16 + tokenizer.padding_side = "right" llm = AutoModelForCausalLM.from_pretrained( params.llm_path_or_name, @@ -733,7 +669,7 @@ def run(rank, world_size, args): ) if not params.unfreeze_llm: - for name, param in llm.named_parameters(): + for name, param in llm.named_parameters(): param.requires_grad = False llm.eval() else: @@ -741,21 +677,31 @@ def run(rank, world_size, args): lora_config = LoraConfig( r=64, lora_alpha=16, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"], + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "up_proj", + "gate_proj", + "down_proj", + ], lora_dropout=0.05, task_type="CAUSAL_LM", ) llm = get_peft_model(llm, lora_config) llm.print_trainable_parameters() - special_tokens_dict = { - "additional_special_tokens": [DEFAULT_SPEECH_TOKEN] - } + special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} tokenizer.add_special_tokens(special_tokens_dict) llm.config.pad_token_id = tokenizer.pad_token_id - llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN) + llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids( + DEFAULT_SPEECH_TOKEN + ) - encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate) + encoder_projector = EncoderProjector( + speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate + ) model = SPEECH_LLM( speech_encoder, @@ -806,7 +752,7 @@ def run(rank, world_size, args): # ) return False return True - + if params.use_aishell: train_cuts = multi_dataset.aishell_train_cuts() else: @@ -814,12 +760,6 @@ def run(rank, world_size, args): train_cuts = train_cuts.filter(remove_short_and_long_utt) - # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # # We only load the sampler's state dict when it loads a checkpoint - # # saved in the middle of an epoch - # sampler_state_dict = 
checkpoints["sampler"] - # else: - # sampler_state_dict = None sampler_state_dict = None if params.sampler_state_dict_path: sampler_state_dict = torch.load(params.sampler_state_dict_path) @@ -840,13 +780,6 @@ def run(rank, world_size, args): else: tb_writer = None - # if params.pretrained_model_path: - # checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") - # if "model" not in checkpoint: - # model.load_state_dict(checkpoint, strict=True) - # else: - # load_checkpoint(params.pretrained_model_path, model) - logging.info(f"start training from epoch {params.start_epoch}") for epoch in range(params.start_epoch, params.num_epochs + 1): @@ -871,12 +804,11 @@ def run(rank, world_size, args): rank=rank, ) - model.save_checkpoint( save_dir=params.exp_dir, tag=f"epoch-{params.cur_epoch}", client_state={}, - exclude_frozen_parameters=True + exclude_frozen_parameters=True, ) if rank == 0: convert_zero_checkpoint_to_fp32_state_dict( @@ -887,13 +819,16 @@ def run(rank, world_size, args): ) # save sampler state dict into checkpoint sampler_state_dict = train_dl.sampler.state_dict() - torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt") - + torch.save( + sampler_state_dict, + f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", + ) + os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") - logging.info("Done!") + def display_and_save_batch( batch: dict, params: AttributeDict,