diff --git a/egs/speech_llm/ASR_LLM/README.md b/egs/speech_llm/ASR_LLM/README.md
index 1e60c733c..171240db0 100644
--- a/egs/speech_llm/ASR_LLM/README.md
+++ b/egs/speech_llm/ASR_LLM/README.md
@@ -1,39 +1,20 @@
# Introduction
-This recipe includes scripts for training Zipformer model using multiple Chinese datasets.
+This recipe includes scripts for training a [Qwen-Audio](https://github.com/QwenLM/Qwen-Audio/tree/main) style model using multiple datasets.
-# Included Training Sets
-1. THCHS-30
-2. AiShell-{1,2,4}
-3. ST-CMDS
-4. Primewords
-5. MagicData
-6. Aidatatang_200zh
-7. AliMeeting
-8. WeNetSpeech
-9. KeSpeech-ASR
+
+![Model framework](./assets/framework.png)
+
-|Datset| Number of hours| URL|
-|---|---:|---|
-|**TOTAL**|14,106|---|
-|THCHS-30|35|https://www.openslr.org/18/|
-|AiShell-1|170|https://www.openslr.org/33/|
-|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
-|AiShell-4|120|https://www.openslr.org/111/|
-|ST-CMDS|110|https://www.openslr.org/38/|
-|Primewords|99|https://www.openslr.org/47/|
-|aidatatang_200zh|200|https://www.openslr.org/62/|
-|MagicData|755|https://www.openslr.org/68/|
-|AliMeeting|100|https://openslr.org/119/|
-|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech|
-|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech|
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# ASR_LLM
+
-# Included Test Sets
-1. Aishell-{1,2,4}
-2. Aidatatang_200zh
-3. AliMeeting
-4. MagicData
-5. KeSpeech-ASR
-6. WeNetSpeech
+The following table lists the folders for different tasks.
+
+| Recipe | Speech Encoder | LLM | Comment |
+|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
+| [whisper_llm_zh](./whisper_llm_zh) | Whisper | Qwen2 | [Using multiple Chinese datasets](https://github.com/k2-fsa/icefall/tree/master/egs/multi_zh-hans/ASR) |
diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md
index a7f3bc4f7..773824c05 100644
--- a/egs/speech_llm/ASR_LLM/RESULTS.md
+++ b/egs/speech_llm/ASR_LLM/RESULTS.md
@@ -1,116 +1,62 @@
## Results
-### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
-#### Whisper
-[./whisper](./whisper)
+### whisper_llm_zh fine-tuning results
-Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.
-
-| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
-| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
-| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
+| Training Dataset | Speech Encoder | LLM | Projector | Comment | CER |
+|------------------|----------------|-----|-----------|---------|-----|
+| Aishell1 | whisper-large-v2-aishell1-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.76% |
+
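+The "Linear, 8x downsample" projector in the table is the `EncoderProjector` defined in [whisper_llm_zh/model.py](./whisper_llm_zh/model.py). The hedged sketch below condenses it to show the tensor shapes involved; the dimensions in the usage example (1280 for the whisper-large-v2 encoder, 1536 assumed for Qwen2-1.5B-Instruct) are illustrative and should be read from the actual model configs.
+
+```python
+import torch
+from torch import nn
+
+
+class EncoderProjector(nn.Module):
+    """Stack `downsample_rate` consecutive encoder frames and map them into the LLM embedding space."""
+
+    def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 8):
+        super().__init__()
+        self.downsample_rate = downsample_rate
+        self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(llm_dim, llm_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch, seq_len, dim = x.size()
+        # drop trailing frames so that seq_len is divisible by the downsample rate
+        x = x[:, : seq_len - seq_len % self.downsample_rate, :]
+        # stack downsample_rate consecutive frames along the feature dimension
+        x = x.contiguous().view(batch, -1, dim * self.downsample_rate)
+        return self.linear2(self.relu(self.linear1(x)))
+
+
+# 30 s of audio gives 1500 whisper encoder frames of dim 1280;
+# after the projector they become 187 "speech tokens" of the LLM width.
+projector = EncoderProjector(encoder_dim=1280, llm_dim=1536, downsample_rate=8)
+print(projector(torch.randn(2, 1500, 1280)).shape)  # torch.Size([2, 187, 1536])
+```
+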
Command for training is:
```bash
-pip install -r whisper/requirements.txt
+pip install -r whisper_llm_zh/requirements.txt
-# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54
+pip install huggingface_hub[cli]
+mkdir -p models/whisper models/qwen
-torchrun --nproc-per-node 8 ./whisper/train.py \
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
- --exp-dir whisper/exp_large_v2 \
- --model-name large-v2 \
+ --exp-dir ./whisper_llm_zh/exp_test \
+ --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+ --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+ --manifest-dir data/fbank \
--deepspeed \
- --deepspeed_config ./whisper/ds_config_zero1.json
+ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+ --use-flash-attn True \
+ --use-lora True --unfreeze-llm True
```
Command for decoding using fine-tuned models:
```bash
-git lfs install
-git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
-ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt
+mkdir -p models/whisper models/qwen models/checkpoint
+huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
-python3 ./whisper/decode.py \
- --exp-dir whisper/exp_large_v2 \
- --model-name large-v2 \
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
+ln -s $(pwd)/models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
+
+python3 ./whisper_llm_zh/decode.py \
+ --max-duration 80 \
+ --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
+ --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+ --llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
- --beam-size 10 --max-duration 50
+ --manifest-dir data/fbank \
+ --use-flash-attn True \
+ --use-lora True --dataset aishell
```
-
-Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
-are available at
-
-
-
-### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model
-
-This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
-
-#### Non-streaming (with CTC head)
-
-Best results (num of params : ~69M):
-
-The training command:
-
-```
-./zipformer/train.py \
- --world-size 4 \
- --num-epochs 20 \
- --use-fp16 1 \
- --max-duration 600 \
- --num-workers 8 \
- --use-ctc 1
-```
-
-The decoding command:
-
-```
-./zipformer/decode.py \
- --epoch 20 \
- --avg 1 \
- --use-ctc 1
-```
-
-Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using BPE model ( # tokens is 2000, byte fallback enabled).
-
-| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
-|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
-| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
-| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 |
-| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 |
-
-Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/
-
-#### Non-streaming
-
-Best results (num of params : ~69M):
-
-The training command:
-
-```
-./zipformer/train.py \
- --world-size 4 \
- --num-epochs 20 \
- --use-fp16 1 \
- --max-duration 600 \
- --num-workers 8
-```
-
-The decoding command:
-
-```
-./zipformer/decode.py \
- --epoch 20 \
- --avg 1
-```
-
-Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled).
-
-| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
-|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
-| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
-| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |
-
-
-Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
diff --git a/egs/speech_llm/ASR_LLM/assets/framework.png b/egs/speech_llm/ASR_LLM/assets/framework.png
new file mode 100644
index 000000000..dc48bda78
Binary files /dev/null and b/egs/speech_llm/ASR_LLM/assets/framework.png differ
diff --git a/egs/speech_llm/ASR_LLM/prepare.sh b/egs/speech_llm/ASR_LLM/prepare.sh
new file mode 100644
index 000000000..6f5ed5448
--- /dev/null
+++ b/egs/speech_llm/ASR_LLM/prepare.sh
@@ -0,0 +1,46 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=0
+stop_stage=0
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "stage 0: Download whisper-large-v2 aishell 1 fbank feature from huggingface"
+
+  # pip install huggingface_hub[cli]
+ # for aishell 1
+ huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse
+
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
+
+ # for multi-hans-zh
+ huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
+ huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
+ huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "stage 2: Download whisper-large-v2 speechio test sets fbank feature from huggingface"
+
+ # for speechio test sets
+  mkdir -p data_speechio
+ huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio
+ mv data_speechio/fbank/* data/fbank
+fi
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index f3bdb452c..882ce4fbf 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -20,21 +20,34 @@
"""
Usage:
# Command for decoding using fine-tuned models:
-git lfs install
-git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
-ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
-python3 ./whisper/decode.py \
- --exp-dir whisper/exp_large_v2 \
- --model-name large-v2 \
+pip install huggingface_hub[cli]
+mkdir -p models/whisper models/qwen models/checkpoint
+huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
+
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
+ln -s $(pwd)/models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
+
+python3 ./whisper_llm_zh/decode.py \
+ --max-duration 80 \
+ --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
+ --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+ --llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
- --beam-size 10 --max-duration 50
-
+ --manifest-dir data/fbank \
+ --use-flash-attn True \
+ --use-lora True --dataset aishell
"""
import argparse
import logging
-import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -42,18 +55,17 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
+import transformers
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
+from model import SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
-#from tn.chinese.normalizer import Normalizer
-#from whisper.normalizers import BasicTextNormalizer
-#from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-#from zhconv import convert
-import transformers
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
-from model import EncoderProjector, SPEECH_LLM
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@@ -63,8 +75,7 @@ from icefall.utils import (
str2bool,
write_error_stats,
)
-from train import DEFAULT_SPEECH_TOKEN
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -117,6 +128,7 @@ def average_checkpoints(
return avg
+
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
@@ -135,7 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
- default=1,
+ default=8,
help="Downsample rate for the encoder projector.",
)
@@ -149,10 +161,11 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--use-lora",
type=str2bool,
- default=False,
- help="Whether to use lora to fine-tune llm.",
+ default=True,
+        help="Whether to use a LoRA fine-tuned LLM checkpoint.",
)
+
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -247,6 +260,7 @@ def decode_one_batch(
Returns:
Return a dict, whose key may be "beam-search".
"""
+
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
@@ -268,10 +282,16 @@ def decode_one_batch(
)
)
max_len_texts = max([len(text) for text in texts])
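+        # pad every sequence to the batch max length on the tokenizer's configured padding side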
- if tokenizer.padding_side == 'right':
- texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
else:
- texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
input_ids = torch.tensor(texts, dtype=torch.int)
@@ -302,16 +322,18 @@ def decode_one_batch(
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
- messages = [[
- {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
- {"role": "assistant", "content": ""},
- ]] * len(feature)
+ messages = [
+ [
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
+ {"role": "assistant", "content": ""},
+ ]
+ ] * len(feature)
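+    # DEFAULT_SPEECH_TOKEN is a placeholder that is later replaced by the projected speech features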
- input_ids, attention_mask = preprocess(
- messages, tokenizer, max_len=128
+ input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
+
+ generated_ids = model.decode(
+ feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
-
- generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return {"beam-search": hyps}
@@ -497,14 +519,14 @@ def main():
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
- # torch_dtype=torch.bfloat16
- torch_dtype=torch.float16
- tokenizer.padding_side = 'left'
+        # TODO: support torch_dtype=torch.bfloat16
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
- torch_dtype=torch.float16
- tokenizer.padding_side = 'right'
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
@@ -515,23 +537,33 @@ def main():
lora_config = LoraConfig(
r=64,
lora_alpha=16,
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "up_proj",
+ "gate_proj",
+ "down_proj",
+ ],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
- special_tokens_dict = {
- "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
- }
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
- llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+ DEFAULT_SPEECH_TOKEN
+ )
- encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
+ encoder_projector = EncoderProjector(
+ speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+ )
model = SPEECH_LLM(
speech_encoder,
@@ -539,7 +571,6 @@ def main():
encoder_projector,
)
-
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
@@ -579,7 +610,7 @@ def main():
#
if c.duration > 30.0:
logging.warning(
- f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
return True
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py
deleted file mode 120000
index e9d239fff..000000000
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index 440724db2..829ef4e2d 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -1,11 +1,20 @@
-from torch import nn
import torch
+from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
-from icefall.dist import get_rank
+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
class EncoderProjector(nn.Module):
-# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
+ """
+ The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
+ Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
+ Args:
+ encoder_dim (:obj:`int`): The dimension of the encoder outputs.
+ llm_dim (:obj:`int`): The dimension of the language model.
+ downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
+ """
+
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
@@ -20,16 +29,30 @@ class EncoderProjector(nn.Module):
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
-
+
x = x.contiguous()
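+        # stack downsample_rate consecutive frames along the feature dim (time downsampling)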
- x = x.view(batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate)
+ x = x.view(
+ batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
+ )
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
+
class SPEECH_LLM(nn.Module):
+ """
+ The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
+ The encoder is used to extract speech features from the input speech signal.
+ The encoder projector is used to project the encoder outputs to the same dimension as the language model.
+ The language model is used to generate the text from the speech features.
+ Args:
+ encoder (:obj:`nn.Module`): The encoder module.
+ llm (:obj:`nn.Module`): The language model module.
+ encoder_projector (:obj:`nn.Module`): The encoder projector module.
+ """
+
def __init__(
self,
encoder: nn.Module,
@@ -41,23 +64,46 @@ class SPEECH_LLM(nn.Module):
self.llm = llm
self.encoder_projector = encoder_projector
- def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
+ def _merge_input_ids_with_speech_features(
+ self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
+ ):
+ """
+ Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
+ with the speech features and padding the input_ids to the maximum length of the speech features.
+ Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
+ Args:
+ speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
+ inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
+ input_ids (:obj:`torch.Tensor`): The input ids to merge.
+ attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
+ labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
+ Returns:
+ :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
+ """
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
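+        # the batch is treated as left-padded when no sequence ends with a pad token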
+ left_padding = not torch.sum(
+ input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
+ )
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
- max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
- batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)
+ max_embed_dim = (
+ num_special_speech_tokens.max() * (speech_len - 1)
+ ) + sequence_length
+ batch_indices, non_speech_indices = torch.where(
+ input_ids != self.llm.config.default_speech_token_id
+ )
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- new_token_positions = torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
+ new_token_positions = (
+ torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
+ )
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_speech_pad[:, None] # offset for left padding
@@ -65,14 +111,24 @@ class SPEECH_LLM(nn.Module):
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ batch_size,
+ max_embed_dim,
+ embed_dim,
+ dtype=inputs_embeds.dtype,
+ device=inputs_embeds.device,
)
final_attention_mask = torch.zeros(
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
+ batch_size,
+ max_embed_dim,
+ dtype=attention_mask.dtype,
+ device=inputs_embeds.device,
)
if labels is not None:
final_labels = torch.full(
- (batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
+ (batch_size, max_embed_dim),
+ IGNORE_TOKEN_ID,
+ dtype=input_ids.dtype,
+ device=input_ids.device,
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
@@ -86,17 +142,28 @@ class SPEECH_LLM(nn.Module):
        # 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_speech_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_speech_indices]
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
+ batch_indices, non_speech_indices
+ ]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
+ batch_indices, non_speech_indices
+ ]
if labels is not None:
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_speech_indices]
+ final_labels[batch_indices, text_to_overwrite] = labels[
+ batch_indices, non_speech_indices
+ ]
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
speech_to_overwrite = torch.full(
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+ (batch_size, max_embed_dim),
+ True,
+ dtype=torch.bool,
+ device=inputs_embeds.device,
)
speech_to_overwrite[batch_indices, text_to_overwrite] = False
- speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[:, None].to(target_device)
+ speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
+ :, None
+ ].to(target_device)
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
raise ValueError(
@@ -104,12 +171,18 @@ class SPEECH_LLM(nn.Module):
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
)
- final_embedding[speech_to_overwrite] = speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ final_embedding[speech_to_overwrite] = (
+ speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ )
final_attention_mask |= speech_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
+ (final_attention_mask == 0), 1
+ )
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
- batch_indices, pad_indices = torch.where(input_ids == self.llm.config.pad_token_id)
+ batch_indices, pad_indices = torch.where(
+ input_ids == self.llm.config.pad_token_id
+ )
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
@@ -119,62 +192,59 @@ class SPEECH_LLM(nn.Module):
return final_embedding, final_attention_mask, final_labels, position_ids
- def forward(self,
- fbank: torch.Tensor = None,
- input_ids: torch.LongTensor = None,
- attention_mask: torch.Tensor = None,
- labels: torch.LongTensor = None,
- ):
+ def forward(
+ self,
+ fbank: torch.Tensor = None,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor = None,
+ labels: torch.LongTensor = None,
+ ):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
-
- inputs_embeds = self.llm.get_input_embeddings()(input_ids)
-
- enable_logging = False
- rank = get_rank()
- # log only on rank 0, training using deep
- if enable_logging and rank == 0:
- print("input_ids", input_ids, input_ids.shape)
- print("labels", labels, labels.shape)
- print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
- print("attention_mask_before", attention_mask.shape, attention_mask)
- print(2333333333333333333333333333)
- inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
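+        # replace each speech placeholder token with the projected speech features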
+ (
+ inputs_embeds,
+ attention_mask,
+ labels,
+ _,
+ ) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
- if enable_logging and rank == 0:
- print("speech_features", speech_features.shape, speech_features)
- print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
- print("attention_mask", attention_mask.shape, attention_mask)
- print("position_ids", position_ids.shape, position_ids)
- print("labels", labels, labels.shape)
- print("================================================================")
- model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
- # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
+ )
+
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
- if enable_logging and rank == 0:
- print("preds", preds, preds.shape)
- print(4555555555555555555555555555555555555555555)
- acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
+ acc = compute_accuracy(
+ preds.detach()[:, :-1],
+ labels.detach()[:, 1:],
+ ignore_label=IGNORE_TOKEN_ID,
+ )
return model_outputs, acc
-
- def decode(self,
- fbank: torch.Tensor = None,
- input_ids: torch.LongTensor = None,
- attention_mask: torch.Tensor = None,
- **kwargs
- ):
+ def decode(
+ self,
+ fbank: torch.Tensor = None,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor = None,
+ **kwargs,
+ ):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
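+        # cast to float16 to match the LLM weights' dtype at inference time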
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
- inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
+ (
+ inputs_embeds,
+ attention_mask,
+ _,
+ position_ids,
+ ) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
@@ -189,7 +259,7 @@ class SPEECH_LLM(nn.Module):
temperature=kwargs.get("temperature", 1.0),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
- pad_token_id=self.llm.config.pad_token_id
+ pad_token_id=self.llm.config.pad_token_id,
)
return generated_ids
@@ -197,7 +267,7 @@ class SPEECH_LLM(nn.Module):
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
-
+ Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
@@ -212,4 +282,4 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
- return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
\ No newline at end of file
+ return numerator.float() / denominator.float()
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
index abfa41b3f..eae967500 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@@ -248,8 +248,6 @@ class MultiDataset:
def aishell_train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
-
- # AISHELL-1
logging.info("Loading Aishell-1 in lazy mode")
aishell_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
@@ -257,11 +255,8 @@ class MultiDataset:
return aishell_cuts
-
def aishell_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
-
- # AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
@@ -271,8 +266,6 @@ class MultiDataset:
def aishell_test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
-
- # AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
@@ -282,12 +275,8 @@ class MultiDataset:
"aishell_test": aishell_test_cuts,
}
-
- # aishell 2
def aishell2_train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
-
- # AISHELL-2
logging.info("Loading Aishell-2 in lazy mode")
aishell_2_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
@@ -297,8 +286,6 @@ class MultiDataset:
def aishell2_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
-
- # AISHELL-2
logging.info("Loading Aishell-2 set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
@@ -308,8 +295,6 @@ class MultiDataset:
def aishell2_test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
-
- # AISHELL-2
logging.info("Loading Aishell-2 set in lazy mode")
aishell2_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
@@ -321,8 +306,6 @@ class MultiDataset:
def wenetspeech_test_meeting_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
-
- # WeNetSpeech
logging.info("Loading WeNetSpeech set in lazy mode")
wenetspeech_test_meeting_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
@@ -352,4 +335,4 @@ class MultiDataset:
test_cuts = load_manifest_lazy(self.fbank_dir / path)
results_dict[partition] = test_cuts
- return results_dict
\ No newline at end of file
+ return results_dict
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt
old mode 100755
new mode 100644
index c5a90cb08..a07c7b157
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt
@@ -5,9 +5,6 @@ sentencepiece
pypinyin
tensorboard
librosa
-# git+https://github.com/yuekaizhang/whisper.git
-# zhconv
-# WeTextProcessing
deepspeed
transformers>=4.37.0
flash-attn
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index f4d30d28a..5f224c984 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -17,14 +17,28 @@
# limitations under the License.
"""
Usage:
+# fine-tuning with whisper and Qwen2
+pip install huggingface_hub[cli]
+mkdir -p models/whisper models/qwen
-#fine-tuning with deepspeed zero stage 1
-torchrun --nproc-per-node 8 ./whisper/train.py \
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
- --exp-dir whisper/exp_large_v2 \
- --model-name large-v2 \
+ --exp-dir ./whisper_llm_zh/exp_test \
+ --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+ --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+ --manifest-dir data/fbank \
--deepspeed \
- --deepspeed_config ./whisper/ds_config_zero1.json
+ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+ --use-flash-attn True \
+ --use-lora True --unfreeze-llm True
"""
import argparse
@@ -39,36 +53,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import k2
-# import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
+import transformers
import whisper
from asr_datamodule import AsrDataModule
-from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
+from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
-# from optim import Eden, ScaledAdam
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
-from torch.cuda.amp import GradScaler
-from torch.nn.functional import pad as pad_tensor
-# from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
-
+from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
-from icefall.checkpoint import load_checkpoint, remove_checkpoints
-from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import update_averaged_model
-from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
+from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
-from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
@@ -77,20 +84,15 @@ from icefall.utils import (
str2bool,
)
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import transformers
-from transformers.trainer_pt_utils import LabelSmoother
-
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-
-#IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 DEFAULT_SPEECH_TOKEN = "<speech>"
+
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
+
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
@@ -109,7 +111,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
- default=1,
+ default=8,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
@@ -133,6 +135,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to unfreeze llm during training.",
)
+
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -162,15 +165,6 @@ def get_parser():
""",
)
- parser.add_argument(
- "--start-batch",
- type=int,
- default=0,
- help="""If positive, --start-epoch is ignored and
- it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
- """,
- )
-
parser.add_argument(
"--exp-dir",
type=str,
@@ -198,26 +192,6 @@ def get_parser():
""",
)
- parser.add_argument(
- "--base-lr", type=float, default=1e-5, help="The base learning rate."
- )
-
- parser.add_argument(
- "--lr-batches",
- type=float,
- default=5000,
- help="""Number of steps that affects how rapidly the learning rate
- decreases. We suggest not to change this.""",
- )
-
- parser.add_argument(
- "--lr-epochs",
- type=float,
- default=6,
- help="""Number of epochs that affects how rapidly the learning rate decreases.
- """,
- )
-
parser.add_argument(
"--seed",
type=int,
@@ -225,44 +199,6 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)
- parser.add_argument(
- "--print-diagnostics",
- type=str2bool,
- default=False,
- help="Accumulate stats on activations, print them and exit.",
- )
-
- parser.add_argument(
- "--inf-check",
- type=str2bool,
- default=False,
- help="Add hooks to check for infinite module outputs and gradients.",
- )
-
- parser.add_argument(
- "--keep-last-k",
- type=int,
- default=30,
- help="""Only keep this number of checkpoints on disk.
- For instance, if it is 3, there are only 3 checkpoints
- in the exp-dir with filenames `checkpoint-xxx.pt`.
- It does not affect checkpoints with name `epoch-xxx.pt`.
- """,
- )
-
- parser.add_argument(
- "--average-period",
- type=int,
- default=200,
- help="""Update the averaged model, namely `model_avg`, after processing
- this number of batches. `model_avg` is a separate version of model,
- in which each floating-point parameter is the average of all the
- parameters from the start of training. Each time we take the average,
- we do: `model_avg = model * (average_period / batch_idx_train) +
- model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
- """,
- )
-
parser.add_argument(
"--use-fp16",
type=str2bool,
@@ -325,6 +261,7 @@ def get_params() -> AttributeDict:
return params
+
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
@@ -372,17 +309,23 @@ def compute_loss(
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
- padding="longest", # FIX me change padding to longest
+            padding="longest",  # pad to the longest sequence in the batch
max_length=max_len,
truncation=True,
)
)
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
max_len_texts = max([len(text) for text in texts])
- if tokenizer.padding_side == 'right':
- texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
+ if tokenizer.padding_side == "right":
+ texts = [
+ text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+ for text in texts
+ ]
else:
- texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
+ texts = [
+ [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+ for text in texts
+ ]
input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone()
@@ -391,13 +334,14 @@ def compute_loss(
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
- mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
- # then mask all tokens before the first token e.g. 151646 (speech), 151645 , 198 \n
+ mask_indices = torch.where(
+ input_ids == tokenizer.convert_tokens_to_ids("assistant")
+ )
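+            # mask everything up to the start of the assistant response so the loss ignores the prompt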
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
- target_ids[row, :col+2] = IGNORE_TOKEN_ID
+ target_ids[row, : col + 2] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
@@ -458,20 +402,13 @@ def compute_loss(
messages = []
for i, text in enumerate(texts):
- # message = [
- # {"role": "system", "content": "你是一个能处理音频的助手。"},
- # {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
- # {"role": "assistant", "content": text},
- # ]
message = [
- {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
- {"role": "assistant", "content": text},
+ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
+ {"role": "assistant", "content": text},
]
messages.append(message)
- input_ids, attention_mask, target_ids = preprocess(
- messages, tokenizer, max_len=128
- )
+ input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
@@ -494,7 +431,9 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
- info["acc"] = acc * info["frames"] # WAR: to avoid normalization by the number of frames
+ info["acc"] = (
+ acc * info["frames"]
+ ) # WAR: to avoid normalization by the number of frames
return loss, info
@@ -607,7 +546,7 @@ def train_one_epoch(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={},
- exclude_frozen_parameters=True
+ exclude_frozen_parameters=True,
)
if rank == 0:
@@ -702,29 +641,26 @@ def run(rank, world_size, args):
logging.info(params)
logging.info("About to create model")
-
- # if 'whisper' in params.speech_encoder_path_or_name:
+
replace_whisper_encoder_forward()
- # TODO: directly loading from whisper-ft checkpoint
- # whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
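+    # the whisper encoder is always frozen; only the projector (and, optionally, LoRA adapters on the LLM) are trained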
- for name, param in speech_encoder.named_parameters():
+ for name, param in speech_encoder.named_parameters():
param.requires_grad = False
speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
- # torch_dtype=torch.bfloat16
- torch_dtype=torch.float16
- tokenizer.padding_side = 'left'
+        # TODO: support torch_dtype=torch.bfloat16
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
- torch_dtype=torch.float16
- tokenizer.padding_side = 'right'
+ torch_dtype = torch.float16
+ tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
@@ -733,7 +669,7 @@ def run(rank, world_size, args):
)
if not params.unfreeze_llm:
- for name, param in llm.named_parameters():
+ for name, param in llm.named_parameters():
param.requires_grad = False
llm.eval()
else:
@@ -741,21 +677,31 @@ def run(rank, world_size, args):
lora_config = LoraConfig(
r=64,
lora_alpha=16,
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "up_proj",
+ "gate_proj",
+ "down_proj",
+ ],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
- special_tokens_dict = {
- "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
- }
+ special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
- llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+ llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+ DEFAULT_SPEECH_TOKEN
+ )
- encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
+ encoder_projector = EncoderProjector(
+ speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+ )
model = SPEECH_LLM(
speech_encoder,
@@ -806,7 +752,7 @@ def run(rank, world_size, args):
# )
return False
return True
-
+
if params.use_aishell:
train_cuts = multi_dataset.aishell_train_cuts()
else:
@@ -814,12 +760,6 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
- # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
- # # We only load the sampler's state dict when it loads a checkpoint
- # # saved in the middle of an epoch
- # sampler_state_dict = checkpoints["sampler"]
- # else:
- # sampler_state_dict = None
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
@@ -840,13 +780,6 @@ def run(rank, world_size, args):
else:
tb_writer = None
- # if params.pretrained_model_path:
- # checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
- # if "model" not in checkpoint:
- # model.load_state_dict(checkpoint, strict=True)
- # else:
- # load_checkpoint(params.pretrained_model_path, model)
-
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
@@ -871,12 +804,11 @@ def run(rank, world_size, args):
rank=rank,
)
-
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={},
- exclude_frozen_parameters=True
+ exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
@@ -887,13 +819,16 @@ def run(rank, world_size, args):
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
- torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
-
+ torch.save(
+ sampler_state_dict,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
+ )
+
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
-
logging.info("Done!")
+
def display_and_save_batch(
batch: dict,
params: AttributeDict,