Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Commit 9ed428d7b1 ("add readme"), parent d1e31c7ac7.
README.md
@ -1,39 +1,20 @@

Old version:

# Introduction

This recipe includes scripts for training a Zipformer model using multiple Chinese datasets.

# Included Training Sets

1. THCHS-30
2. AiShell-{1,2,4}
3. ST-CMDS
4. Primewords
5. MagicData
6. Aidatatang_200zh
7. AliMeeting
8. WeNetSpeech
9. KeSpeech-ASR

|Dataset| Number of hours| URL|
|---|---:|---|
|**TOTAL**|14,106|---|
|THCHS-30|35|https://www.openslr.org/18/|
|AiShell-1|170|https://www.openslr.org/33/|
|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
|AiShell-4|120|https://www.openslr.org/111/|
|ST-CMDS|110|https://www.openslr.org/38/|
|Primewords|99|https://www.openslr.org/47/|
|aidatatang_200zh|200|https://www.openslr.org/62/|
|MagicData|755|https://www.openslr.org/68/|
|AliMeeting|100|https://openslr.org/119/|
|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech|
|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech|

# Included Test Sets

1. Aishell-{1,2,4}
2. Aidatatang_200zh
3. AliMeeting
4. MagicData
5. KeSpeech-ASR
6. WeNetSpeech

New version:

# Introduction

This recipe includes scripts for training a [Qwen-Audio](https://github.com/QwenLM/Qwen-Audio/tree/main) style model using multiple datasets.

<br>
<p align="center">
<img src="assets/framework.png" width="800"/>
<p>
<br>

[./RESULTS.md](./RESULTS.md) contains the latest results.

# ASR_LLM

The following table lists the folders for different tasks.

| | Speech Encoder | LLM | Comment |
|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
| [whisper_llm_zh](./whisper_llm_zh) | Whisper | Qwen2 | [Using multiple Chinese datasets](https://github.com/k2-fsa/icefall/tree/master/egs/multi_zh-hans/ASR) |
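The framework figure referenced above is a standard speech-encoder + projector + LLM stack. As a rough sketch of how those three pieces fit together (class names follow whisper_llm_zh/model.py further down in this diff; the forward shown here is schematic, not the recipe's exact code, which splices the speech features at a placeholder token rather than prepending them):

```python
import torch
from torch import nn


class SpeechLLMSketch(nn.Module):
    """Schematic Qwen-Audio style wiring: frozen speech encoder -> projector -> LLM."""

    def __init__(self, speech_encoder: nn.Module, encoder_projector: nn.Module, llm: nn.Module):
        super().__init__()
        self.encoder = speech_encoder                # e.g. a fine-tuned Whisper encoder
        self.encoder_projector = encoder_projector   # linear projector with 8x downsampling
        self.llm = llm                               # e.g. Qwen2, optionally with LoRA adapters

    def forward(self, fbank: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        encoder_outs = self.encoder(fbank)                    # (B, T, encoder_dim)
        speech_embeds = self.encoder_projector(encoder_outs)  # (B, roughly T / 8, llm_dim)
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        # The real model replaces the <speech> placeholder with the speech
        # embeddings; here we simply prepend them for illustration.
        inputs_embeds = torch.cat([speech_embeds, text_embeds], dim=1)
        speech_mask = torch.ones(speech_embeds.shape[:2], dtype=attention_mask.dtype)
        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=torch.cat([speech_mask, attention_mask], dim=1),
        )
```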
RESULTS.md
@ -1,116 +1,62 @@

Old version:

## Results

### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2

#### Whisper

[./whisper](./whisper)

Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Split | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |

Command for training is:
```bash
pip install -r whisper/requirements.txt

# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54

torchrun --nproc-per-node 8 ./whisper/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --deepspeed \
  --deepspeed_config ./whisper/ds_config_zero1.json
```

Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --epoch 999 --avg 1 \
  --beam-size 10 --max-duration 50
```

Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>

### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.

#### Non-streaming (with CTC head)

Best results (number of parameters: ~69M):

The training command:

```
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 20 \
  --use-fp16 1 \
  --max-duration 600 \
  --num-workers 8 \
  --use-ctc 1
```

The decoding command:

```
./zipformer/decode.py \
  --epoch 20 \
  --avg 1 \
  --use-ctc 1
```

Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using a BPE model (number of tokens: 2000, byte fallback enabled).

| Datasets | aidatatang_200zh | aidatatang_200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 |
| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 |

The pre-trained model can be found here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/

#### Non-streaming

Best results (number of parameters: ~69M):

The training command:

```
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 20 \
  --use-fp16 1 \
  --max-duration 600 \
  --num-workers 8
```

The decoding command:

```
./zipformer/decode.py \
  --epoch 20 \
  --avg 1
```

Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and a BPE model (number of tokens: 2000, byte fallback enabled).

| Datasets | aidatatang_200zh | aidatatang_200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |

The pre-trained model can be found here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/

New version:

## Results

### whisper_llm_zh finetuning results

| Training Dataset | Speech Encoder | LLM | Projector | Comment | CER |
|---|---|---|---|---|---|
| Aishell1 | whisper-large-v2-aishell1-ft, freeze | Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample | [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.76% |
<!-- | Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze | Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample | WIP ||
| Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze | Qwen2-7B-Instruct, LoRA | Linear, 8x downsample | WIP || -->

Command for training is:
```bash
pip install -r whisper_llm_zh/requirements.txt

pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen

# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True
```

Command for decoding using fine-tuned models:
```bash
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
```
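All numbers in both versions of this file are Character Error Rates, i.e. the character-level edit distance divided by the number of reference characters. The recipe computes this with icefall's write_error_stats; the standalone sketch below is only meant to make the metric concrete:

```python
def cer(ref: str, hyp: str) -> float:
    """Character error rate: (substitutions + deletions + insertions) / len(ref)."""
    r, h = list(ref), list(hyp)
    # Standard dynamic-programming edit distance over characters.
    dp = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        dp[i][0] = i
    for j in range(len(h) + 1):
        dp[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost)
    return dp[len(r)][len(h)] / max(len(r), 1)


print(cer("今天天气不错", "今天气很不错"))  # 2 edits over 6 reference characters, about 0.33
```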
egs/speech_llm/ASR_LLM/assets/framework.png: new binary file (834 KiB), binary content not shown.
egs/speech_llm/ASR_LLM/prepare.sh (new file, 46 lines)
@ -0,0 +1,46 @@

#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

stage=0
stop_stage=0
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "stage 0: Download whisper-large-v2 aishell 1 fbank feature from huggingface"

  # pip install huggingface_hub['cli']
  # for aishell 1
  huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"

  # for multi-hans-zh
  huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
  huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
  huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "stage 2: Download whisper-large-v2 speechio test sets fbank feature from huggingface"

  # for speechio test sets
  mkdir data_speechio
  huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio
  mv data_speechio/fbank/* data/fbank
fi
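As a quick sanity check after running the stages above, the downloaded Lhotse manifests can be opened directly from data/fbank. This is a minimal sketch, assuming the aishell cut manifests land at the paths used later by multi_dataset.py; adjust the filenames if your layout differs:

```python
from lhotse import load_manifest_lazy

# Assumed location, matching multi_dataset.py in this recipe.
cuts = load_manifest_lazy("data/fbank/aishell_cuts_dev.jsonl.gz")

total_seconds = 0.0
num_cuts = 0
for cut in cuts:
    total_seconds += cut.duration
    num_cuts += 1
print(f"{num_cuts} cuts, {total_seconds / 3600:.2f} hours")

# Each cut should carry precomputed whisper fbank features.
first = next(iter(load_manifest_lazy("data/fbank/aishell_cuts_dev.jsonl.gz")))
print(first.has_features, first.num_frames, first.num_features)
```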
whisper_llm_zh/decode.py
@ -20,21 +20,34 @@

The usage docstring is rewritten for the Whisper + Qwen2 recipe. The removed version documented whisper/decode.py (git lfs install; git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper; ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt; python3 ./whisper/decode.py --exp-dir whisper/exp_large_v2 --model-name large-v2 --epoch 999 --avg 1 --beam-size 10 --max-duration 50). The new docstring reads:

"""
Usage:
# Command for decoding using fine-tuned models:

pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
    --max-duration 80 \
    --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
    --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
    --llm-path-or-name models/qwen \
    --epoch 999 --avg 1 \
    --manifest-dir data/fbank \
    --use-flash-attn True \
    --use-lora True --dataset aishell
"""

import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

Removed: the now-unused import re.
@ -42,18 +55,17 @@ from typing import Dict, List, Optional, Tuple

The import block is sorted and the commented-out imports (tn.chinese.normalizer Normalizer, whisper.normalizers BasicTextNormalizer, whisper_decoder_forward_monkey_patch, zhconv convert) are removed:

import k2
import torch
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (

@ -63,8 +75,7 @@ from icefall.utils import (

    str2bool,
    write_error_stats,
)

Removed from this position (now in the sorted block above): from train import DEFAULT_SPEECH_TOKEN; from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training.

@ -117,6 +128,7 @@ def average_checkpoints(

    return avg


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--llm-path-or-name",

@ -135,7 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):

    parser.add_argument(
        "--encoder-projector-ds-rate",
        type=int,
        default=8,  # was: default=1
        help="Downsample rate for the encoder projector.",
    )

@ -149,10 +161,11 @@ def add_model_arguments(parser: argparse.ArgumentParser):

    parser.add_argument(
        "--use-lora",
        type=str2bool,
        default=True,  # was: default=False
        help="Whether to use lora fine-tuned llm checkpoint.",  # was: "Whether to use lora to fine-tune llm."
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -247,6 +260,7 @@ def decode_one_batch(

    Returns:
      Return a dict, whose key may be "beam-search".
    """

    def preprocess(
        messages,
        tokenizer: transformers.PreTrainedTokenizer,

@ -268,10 +282,16 @@ def decode_one_batch(

The prompt-padding code in preprocess() is reflowed to black style (behavior unchanged): token id lists are padded to the longest sequence in the batch, on the right or the left depending on tokenizer.padding_side:

            )
        )
        max_len_texts = max([len(text) for text in texts])
        if tokenizer.padding_side == "right":
            texts = [
                text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
                for text in texts
            ]
        else:
            texts = [
                [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
                for text in texts
            ]

        input_ids = torch.tensor(texts, dtype=torch.int)
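For illustration, here is a small, self-contained sketch of what that padding step produces for a two-prompt batch; the pad id 0 is an arbitrary stand-in for tokenizer.pad_token_id:

```python
import torch

pad_token_id = 0  # arbitrary stand-in for tokenizer.pad_token_id
texts = [[11, 12, 13, 14], [21, 22]]  # token ids for two prompts
max_len = max(len(t) for t in texts)

# padding_side == "right": pad after the tokens
right_padded = [t + [pad_token_id] * (max_len - len(t)) for t in texts]
# padding_side == "left": pad before the tokens (used with flash-attn in this recipe)
left_padded = [[pad_token_id] * (max_len - len(t)) + t for t in texts]

print(torch.tensor(right_padded))  # [[11, 12, 13, 14], [21, 22,  0,  0]]
print(torch.tensor(left_padded))   # [[11, 12, 13, 14], [ 0,  0, 21, 22]]
```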
@ -302,16 +322,18 @@ def decode_one_batch(

The per-batch prompt construction is reflowed: one user/assistant message pair per utterance, where the user message is the speech placeholder followed by a Chinese instruction meaning "please transcribe the audio to text". The preprocess and model.decode calls are wrapped to black style:

    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)

    messages = [
        [
            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            {"role": "assistant", "content": ""},
        ]
    ] * len(feature)

    input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)

    generated_ids = model.decode(
        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
    )
    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    return {"beam-search": hyps}
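The preprocess helper relies on the tokenizer's chat template to turn those messages into token ids. As a rough standalone illustration (assuming the Qwen2-1.5B-Instruct tokenizer used elsewhere in this recipe, with <speech> registered as an additional special token; the real recipe builds and pads the id sequences inside preprocess()):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
tokenizer.add_special_tokens({"additional_special_tokens": ["<speech>"]})

messages = [
    {"role": "user", "content": "<speech>请转写音频为文字"},
    {"role": "assistant", "content": ""},
]

# Render the conversation with the model's chat template, then tokenize it.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
print(prompt)
print(input_ids.shape)
```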
@ -497,14 +519,14 @@ def main():

The attention implementation, dtype and padding-side selection gets consistent spacing and quoting, and the commented-out bfloat16 line is marked FIX ME:

    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16 FIX ME
        torch_dtype = torch.float16
        tokenizer.padding_side = "left"

    else:
        attn_implementation = "eager"
        torch_dtype = torch.float16
        tokenizer.padding_side = "right"

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
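Those variables feed directly into the AutoModelForCausalLM call that follows. A minimal sketch of that load, with the model name assumed to be the Qwen2-1.5B-Instruct checkpoint used elsewhere in this recipe:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

llm_path = "Qwen/Qwen2-1.5B-Instruct"  # or a local directory such as models/qwen

tokenizer = AutoTokenizer.from_pretrained(llm_path)
tokenizer.padding_side = "left"  # left padding pairs with flash attention here

llm = AutoModelForCausalLM.from_pretrained(
    llm_path,
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a GPU
    torch_dtype=torch.float16,
)
```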
@ -515,23 +537,33 @@ def main():

The LoRA configuration, special-token registration and projector construction are reflowed to black style (no behavioral change):

    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "gate_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )
    llm = get_peft_model(llm, lora_config)
    llm.print_trainable_parameters()

    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
    )

    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    model = SPEECH_LLM(
        speech_encoder,
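For reference, the same LoRA setup can be exercised standalone; this is a sketch assuming the Qwen2-1.5B-Instruct checkpoint (any causal LM exposing the listed projection module names would do):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "up_proj", "gate_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
# Prints how many parameters the adapters add relative to the frozen base model.
llm.print_trainable_parameters()
```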
@ -539,7 +571,6 @@ def main():

A stray blank line is dropped:

        encoder_projector,
    )

    if params.avg > 1:
        start = params.epoch - params.avg + 1
        assert start >= 1, start

@ -579,7 +610,7 @@ def main():

    #
    if c.duration > 30.0:
        logging.warning(
            f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
        )
        return False
    return True

@ -1 +0,0 @@

Deleted: a one-line file containing ../../../librispeech/ASR/conformer_ctc/label_smoothing.py (a symlink into the librispeech conformer_ctc recipe).
whisper_llm_zh/model.py
@ -1,11 +1,20 @@

The imports are reordered, the unused get_rank import from icefall.dist is dropped, and EncoderProjector gains a docstring in place of the bare SLAM-LLM URL comment:

import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class EncoderProjector(nn.Module):
    """
    The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
    """

    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
        super().__init__()
        self.downsample_rate = downsample_rate
@ -20,16 +29,30 @@ class EncoderProjector:

EncoderProjector.forward is reflowed and SPEECH_LLM gains a class docstring:

        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )

        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


class SPEECH_LLM(nn.Module):
    """
    The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
    The encoder is used to extract speech features from the input speech signal.
    The encoder projector is used to project the encoder outputs to the same dimension as the language model.
    The language model is used to generate the text from the speech features.
    Args:
        encoder (:obj:`nn.Module`): The encoder module.
        llm (:obj:`nn.Module`): The language model module.
        encoder_projector (:obj:`nn.Module`): The encoder projector module.
    """

    def __init__(
        self,
        encoder: nn.Module,
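To make the frame stacking concrete: the projector drops trailing frames so the sequence length is divisible by the downsample rate, stacks consecutive frames along the feature dimension, and maps the result to the LLM width. A small sketch with illustrative dimensions (1280 is the Whisper-large encoder width; 1536 is assumed here as the Qwen2-1.5B hidden size; the exact layer widths in the real __init__ are not shown in this hunk, so the llm_dim to llm_dim second layer is an assumption):

```python
import torch
from torch import nn


class EncoderProjectorSketch(nn.Module):
    # Mirrors the forward() above: stack `downsample_rate` consecutive frames,
    # then project encoder_dim * downsample_rate -> llm_dim with a small MLP.
    def __init__(self, encoder_dim, llm_dim, downsample_rate=8):
        super().__init__()
        self.downsample_rate = downsample_rate
        self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x):
        batch_size, seq_len, feat_dim = x.size()
        num_frames_to_discard = seq_len % self.downsample_rate
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)
        x = x.contiguous().view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )
        return self.linear2(self.relu(self.linear1(x)))


projector = EncoderProjectorSketch(encoder_dim=1280, llm_dim=1536, downsample_rate=8)
encoder_out = torch.randn(2, 1500, 1280)  # e.g. 30 s of Whisper encoder frames
print(projector(encoder_out).shape)  # torch.Size([2, 187, 1536]); 1500 // 8 = 187
```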
@ -41,23 +64,46 @@ class SPEECH_LLM(nn.Module):

_merge_input_ids_with_speech_features gains a docstring and is reflowed:

        self.llm = llm
        self.encoder_projector = encoder_projector

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
    ):
        """
        Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
        with the speech features and padding the input_ids to the maximum length of the speech features.
        Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
        Args:
            speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
            inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
            input_ids (:obj:`torch.Tensor`): The input ids to merge.
            attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
            labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
        Returns:
            :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
        """
        num_speechs, speech_len, embed_dim = speech_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
        )
        # 1. Create a mask to know where special speech tokens are
        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
        num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (
            num_special_speech_tokens.max() * (speech_len - 1)
        ) + sequence_length
        batch_indices, non_speech_indices = torch.where(
            input_ids != self.llm.config.default_speech_token_id
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged speech-text sequence.
        # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
        # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = (
            torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        )
        nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_speech_pad[:, None]  # offset for left padding
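A tiny worked example of the position arithmetic above, with one speech placeholder per prompt and three projected speech frames (speech_len = 3); the id values are arbitrary:

```python
import torch

speech_len = 3                # number of projected speech frames per utterance
default_speech_token_id = 99  # arbitrary stand-in for the <speech> token id
input_ids = torch.tensor([[7, 99, 15, 16]])  # schematic: bos, <speech>, "how", "are"

special_speech_token_mask = input_ids == default_speech_token_id
new_token_positions = (
    torch.cumsum(special_speech_token_mask * (speech_len - 1) + 1, -1) - 1
)
print(new_token_positions)  # tensor([[0, 3, 4, 5]])

# The non-speech tokens (7, 15, 16) are copied to positions 0, 4 and 5 of the
# merged sequence; the three speech frames fill positions 1-3, giving a merged
# length of 4 + 1 * (speech_len - 1) = 6, which matches max_embed_dim.
```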
@ -65,14 +111,24 @@ class SPEECH_LLM(nn.Module):

The creation of final_embedding, final_attention_mask and final_labels is reflowed:

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                IGNORE_TOKEN_ID,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.

@ -86,17 +142,28 @@ class SPEECH_LLM(nn.Module):

The index-copy of text embeddings and the speech mask construction are reflowed:

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_speech_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_speech_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_speech_indices
            ]

        # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
        speech_to_overwrite = torch.full(
            (batch_size, max_embed_dim),
            True,
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        speech_to_overwrite[batch_indices, text_to_overwrite] = False
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
            :, None
        ].to(target_device)

        if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
            raise ValueError(

@ -104,12 +171,18 @@ class SPEECH_LLM(nn.Module):

Filling the speech positions and computing position_ids are reflowed:

                f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[speech_to_overwrite] = (
            speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        final_attention_mask |= speech_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1
        )

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(
            input_ids == self.llm.config.pad_token_id
        )
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

@ -119,62 +192,59 @@ class SPEECH_LLM(nn.Module):

forward() is cleaned up: the rank-0 debug printing (guarded by enable_logging and using get_rank(), printing input_ids, labels, inputs_embeds, attention_mask, speech_features, position_ids and preds) and a commented-out alternative self.llm call that passed position_ids are removed, and the remaining calls are reflowed. decode() gets a black-style signature:

        return final_embedding, final_attention_mask, final_labels, position_ids

    def forward(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.LongTensor = None,
    ):
        encoder_outs = self.encoder(fbank)

        speech_features = self.encoder_projector(encoder_outs)

        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
        )

        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
        return model_outputs, acc

    def decode(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):

        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(torch.float16)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        (
            inputs_embeds,
            attention_mask,
            _,
            position_ids,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask
        )
        generated_ids = self.llm.generate(

@ -189,7 +259,7 @@ class SPEECH_LLM(nn.Module):

A trailing comma is added to the generate() call:

            temperature=kwargs.get("temperature", 1.0),
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id,
        )

        return generated_ids
@ -197,7 +267,7 @@ class SPEECH_LLM(nn.Module):

compute_accuracy credits its source in the docstring, and a stray trailing comment on the return line is dropped:

def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.
    Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).

@ -212,4 +282,4 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):

        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()
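A quick standalone check of compute_accuracy's behaviour (values chosen arbitrarily; IGNORE_TOKEN_ID is -100, the LabelSmoother ignore index):

```python
import torch

IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index


def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    # Same logic as above: count matches only where the target is not ignored.
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()


preds = torch.tensor([[5, 7, 9]])
targets = torch.tensor([[5, 8, IGNORE_TOKEN_ID]])  # last position is padding
print(compute_accuracy(preds, targets, IGNORE_TOKEN_ID))  # tensor(0.5000)
```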
whisper_llm_zh/multi_dataset.py

Redundant dataset-name comments (e.g. "# AISHELL-1", "# aishell 2") and the blank lines around them are removed from the cut-loading helpers; the code is otherwise unchanged. The affected hunks now read:

@ -248,8 +248,6 @@ class MultiDataset:

    def aishell_train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")
        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
        )

@ -257,11 +255,8 @@ class MultiDataset:

        return aishell_cuts

    def aishell_dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")
        logging.info("Loading Aishell set in lazy mode")
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

@ -271,8 +266,6 @@ class MultiDataset:

    def aishell_test_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
        )

@ -282,12 +275,8 @@ class MultiDataset:

            "aishell_test": aishell_test_cuts,
        }

    def aishell2_train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
        )

@ -297,8 +286,6 @@ class MultiDataset:

    def aishell2_dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

@ -308,8 +295,6 @@ class MultiDataset:

    def aishell2_test_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
        )

@ -321,8 +306,6 @@ class MultiDataset:

    def wenetspeech_test_meeting_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
        )

@ -352,4 +335,4 @@ class MultiDataset:

            test_cuts = load_manifest_lazy(self.fbank_dir / path)
            results_dict[partition] = test_cuts

        return results_dict
egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt (3 lines removed; mode changed from executable to normal file)
@ -5,9 +5,6 @@ sentencepiece

pypinyin
tensorboard
librosa
deepspeed
transformers>=4.37.0
flash-attn

Removed: the commented-out entries "# git+https://github.com/yuekaizhang/whisper.git", "# zhconv" and "# WeTextProcessing".
whisper_llm_zh/train.py
@ -17,14 +17,28 @@

The usage docstring now documents fine-tuning with Whisper and Qwen2. The removed version ("fine-tuning with deepspeed zero stage 1") documented ./whisper/train.py with --exp-dir whisper/exp_large_v2, --model-name large-v2 and ./whisper/ds_config_zero1.json. The new docstring reads:

# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen

# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
    --max-duration 200 \
    --exp-dir ./whisper_llm_zh/exp_test \
    --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
    --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
    --manifest-dir data/fbank \
    --deepspeed \
    --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
    --use-flash-attn True \
    --use-lora True --unfreeze-llm True
"""

import argparse
@ -39,36 +53,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union

The import block is sorted and unused imports are removed ("# import optim", "# from optim import Eden, ScaledAdam", GradScaler, pad_tensor, the commented-out DistributedDataParallel import, icefall.checkpoint load_checkpoint / remove_checkpoints / save_checkpoint_impl / update_averaged_model, cleanup_dist / setup_dist, and register_inf_check_hooks):

import deepspeed
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    MetricsTracker,

@ -77,20 +84,15 @@ from icefall.utils import (

The duplicate transformers, peft and LabelSmoother imports below the icefall block are removed (they now live in the sorted block above), along with the commented-out "#IGNORE_TOKEN_ID = LabelSmoother.ignore_index" line:

    str2bool,
)

DEFAULT_SPEECH_TOKEN = "<speech>"


def set_batch_count(model: nn.Module, batch_count: float) -> None:
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--llm-path-or-name",
@ -109,7 +111,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-projector-ds-rate",
|
"--encoder-projector-ds-rate",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=8,
|
||||||
help="Downsample rate for the encoder projector.",
|
help="Downsample rate for the encoder projector.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -133,6 +135,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="Whether to unfreeze llm during training.",
|
help="Whether to unfreeze llm during training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -162,15 +165,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--start-batch",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="""If positive, --start-epoch is ignored and
|
|
||||||
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -198,26 +192,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--base-lr", type=float, default=1e-5, help="The base learning rate."
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--lr-batches",
|
|
||||||
type=float,
|
|
||||||
default=5000,
|
|
||||||
help="""Number of steps that affects how rapidly the learning rate
|
|
||||||
decreases. We suggest not to change this.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--lr-epochs",
|
|
||||||
type=float,
|
|
||||||
default=6,
|
|
||||||
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
@ -225,44 +199,6 @@ def get_parser():
|
|||||||
help="The seed for random generators intended for reproducibility",
|
help="The seed for random generators intended for reproducibility",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--print-diagnostics",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Accumulate stats on activations, print them and exit.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--inf-check",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Add hooks to check for infinite module outputs and gradients.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--keep-last-k",
|
|
||||||
type=int,
|
|
||||||
default=30,
|
|
||||||
help="""Only keep this number of checkpoints on disk.
|
|
||||||
For instance, if it is 3, there are only 3 checkpoints
|
|
||||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
|
||||||
It does not affect checkpoints with name `epoch-xxx.pt`.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--average-period",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="""Update the averaged model, namely `model_avg`, after processing
|
|
||||||
this number of batches. `model_avg` is a separate version of model,
|
|
||||||
in which each floating-point parameter is the average of all the
|
|
||||||
parameters from the start of training. Each time we take the average,
|
|
||||||
we do: `model_avg = model * (average_period / batch_idx_train) +
|
|
||||||
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-fp16",
|
"--use-fp16",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -325,6 +261,7 @@ def get_params() -> AttributeDict:
|
|||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
@ -372,17 +309,23 @@ def compute_loss(
|
|||||||
tokenize=True,
|
tokenize=True,
|
||||||
chat_template=TEMPLATE,
|
chat_template=TEMPLATE,
|
||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
padding="longest", # FIX me change padding to longest
|
padding="longest", # FIX me change padding to longest
|
||||||
max_length=max_len,
|
max_length=max_len,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
|
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
|
||||||
max_len_texts = max([len(text) for text in texts])
|
max_len_texts = max([len(text) for text in texts])
|
||||||
if tokenizer.padding_side == 'right':
|
if tokenizer.padding_side == "right":
|
||||||
texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
|
texts = [
|
||||||
|
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
||||||
|
for text in texts
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
|
texts = [
|
||||||
|
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
||||||
|
for text in texts
|
||||||
|
]
|
||||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||||
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
|
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
|
||||||
target_ids = input_ids.clone()
|
target_ids = input_ids.clone()
|
||||||
@ -391,13 +334,14 @@ def compute_loss(
|
|||||||
# first get the indices of the tokens
|
# first get the indices of the tokens
|
||||||
mask_prompt = True
|
mask_prompt = True
|
||||||
if mask_prompt:
|
if mask_prompt:
|
||||||
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
|
mask_indices = torch.where(
|
||||||
# then mask all tokens before the first token e.g. 151646 (speech), 151645 <assistant>, 198 \n
|
input_ids == tokenizer.convert_tokens_to_ids("assistant")
|
||||||
|
)
|
||||||
for i in range(mask_indices[0].size(0)):
|
for i in range(mask_indices[0].size(0)):
|
||||||
row = mask_indices[0][i]
|
row = mask_indices[0][i]
|
||||||
col = mask_indices[1][i]
|
col = mask_indices[1][i]
|
||||||
# + 2 to skip: 'assistant', '\n'
|
# + 2 to skip: 'assistant', '\n'
|
||||||
target_ids[row, :col+2] = IGNORE_TOKEN_ID
|
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
|
||||||
|
|
||||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||||
|
|
||||||
@ -458,20 +402,13 @@ def compute_loss(
|
|||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
# message = [
|
|
||||||
# {"role": "system", "content": "你是一个能处理音频的助手。"},
|
|
||||||
# {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
|
|
||||||
# {"role": "assistant", "content": text},
|
|
||||||
# ]
|
|
||||||
message = [
|
message = [
|
||||||
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
|
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
|
||||||
{"role": "assistant", "content": text},
|
{"role": "assistant", "content": text},
|
||||||
]
|
]
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
||||||
input_ids, attention_mask, target_ids = preprocess(
|
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
|
||||||
messages, tokenizer, max_len=128
|
|
||||||
)
|
|
||||||
|
|
||||||
target_ids = target_ids.type(torch.LongTensor)
|
target_ids = target_ids.type(torch.LongTensor)
|
||||||
input_ids = input_ids.type(torch.LongTensor)
|
input_ids = input_ids.type(torch.LongTensor)
|
||||||
@ -494,7 +431,9 @@ def compute_loss(
|
|||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["acc"] = acc * info["frames"] # WAR: to avoid normalization by the number of frames
|
info["acc"] = (
|
||||||
|
acc * info["frames"]
|
||||||
|
) # WAR: to avoid normalization by the number of frames
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -607,7 +546,7 @@ def train_one_epoch(
|
|||||||
save_dir=params.exp_dir,
|
save_dir=params.exp_dir,
|
||||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||||
client_state={},
|
client_state={},
|
||||||
exclude_frozen_parameters=True
|
exclude_frozen_parameters=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
@ -702,29 +641,26 @@ def run(rank, world_size, args):
|
|||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
|
||||||
# if 'whisper' in params.speech_encoder_path_or_name:
|
|
||||||
replace_whisper_encoder_forward()
|
replace_whisper_encoder_forward()
|
||||||
# TODO: directly loading from whisper-ft checkpoint
|
|
||||||
# whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
|
||||||
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
|
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
|
||||||
speech_encoder = whisper_model.encoder
|
speech_encoder = whisper_model.encoder
|
||||||
speech_encoder_dim = whisper_model.dims.n_audio_state
|
speech_encoder_dim = whisper_model.dims.n_audio_state
|
||||||
for name, param in speech_encoder.named_parameters():
|
for name, param in speech_encoder.named_parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
speech_encoder.eval()
|
speech_encoder.eval()
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||||
if params.use_flash_attn:
|
if params.use_flash_attn:
|
||||||
attn_implementation = "flash_attention_2"
|
attn_implementation = "flash_attention_2"
|
||||||
# torch_dtype=torch.bfloat16
|
# torch_dtype=torch.bfloat16 FIX ME
|
||||||
torch_dtype=torch.float16
|
torch_dtype = torch.float16
|
||||||
tokenizer.padding_side = 'left'
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
attn_implementation = "eager"
|
attn_implementation = "eager"
|
||||||
torch_dtype=torch.float16
|
torch_dtype = torch.float16
|
||||||
tokenizer.padding_side = 'right'
|
tokenizer.padding_side = "right"
|
||||||
|
|
||||||
llm = AutoModelForCausalLM.from_pretrained(
|
llm = AutoModelForCausalLM.from_pretrained(
|
||||||
params.llm_path_or_name,
|
params.llm_path_or_name,
|
||||||
@ -733,7 +669,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not params.unfreeze_llm:
|
if not params.unfreeze_llm:
|
||||||
for name, param in llm.named_parameters():
|
for name, param in llm.named_parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
llm.eval()
|
llm.eval()
|
||||||
else:
|
else:
|
||||||
@ -741,21 +677,31 @@ def run(rank, world_size, args):
|
|||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
r=64,
|
r=64,
|
||||||
lora_alpha=16,
|
lora_alpha=16,
|
||||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
|
target_modules=[
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
"o_proj",
|
||||||
|
"up_proj",
|
||||||
|
"gate_proj",
|
||||||
|
"down_proj",
|
||||||
|
],
|
||||||
lora_dropout=0.05,
|
lora_dropout=0.05,
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
llm = get_peft_model(llm, lora_config)
|
llm = get_peft_model(llm, lora_config)
|
||||||
llm.print_trainable_parameters()
|
llm.print_trainable_parameters()
|
||||||
|
|
||||||
special_tokens_dict = {
|
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
|
||||||
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
|
|
||||||
}
|
|
||||||
tokenizer.add_special_tokens(special_tokens_dict)
|
tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
llm.config.pad_token_id = tokenizer.pad_token_id
|
llm.config.pad_token_id = tokenizer.pad_token_id
|
||||||
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
|
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
|
||||||
|
DEFAULT_SPEECH_TOKEN
|
||||||
|
)
|
||||||
|
|
||||||
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
|
encoder_projector = EncoderProjector(
|
||||||
|
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
|
||||||
|
)
|
||||||
|
|
||||||
model = SPEECH_LLM(
|
model = SPEECH_LLM(
|
||||||
speech_encoder,
|
speech_encoder,
|
||||||
@ -806,7 +752,7 @@ def run(rank, world_size, args):
|
|||||||
# )
|
# )
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if params.use_aishell:
|
if params.use_aishell:
|
||||||
train_cuts = multi_dataset.aishell_train_cuts()
|
train_cuts = multi_dataset.aishell_train_cuts()
|
||||||
else:
|
else:
|
||||||
@ -814,12 +760,6 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
|
|
||||||
# if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
|
||||||
# # We only load the sampler's state dict when it loads a checkpoint
|
|
||||||
# # saved in the middle of an epoch
|
|
||||||
# sampler_state_dict = checkpoints["sampler"]
|
|
||||||
# else:
|
|
||||||
# sampler_state_dict = None
|
|
||||||
sampler_state_dict = None
|
sampler_state_dict = None
|
||||||
if params.sampler_state_dict_path:
|
if params.sampler_state_dict_path:
|
||||||
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||||
@ -840,13 +780,6 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
tb_writer = None
|
tb_writer = None
|
||||||
|
|
||||||
# if params.pretrained_model_path:
|
|
||||||
# checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
|
||||||
# if "model" not in checkpoint:
|
|
||||||
# model.load_state_dict(checkpoint, strict=True)
|
|
||||||
# else:
|
|
||||||
# load_checkpoint(params.pretrained_model_path, model)
|
|
||||||
|
|
||||||
logging.info(f"start training from epoch {params.start_epoch}")
|
logging.info(f"start training from epoch {params.start_epoch}")
|
||||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||||
|
|
||||||
@ -871,12 +804,11 @@ def run(rank, world_size, args):
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
model.save_checkpoint(
|
model.save_checkpoint(
|
||||||
save_dir=params.exp_dir,
|
save_dir=params.exp_dir,
|
||||||
tag=f"epoch-{params.cur_epoch}",
|
tag=f"epoch-{params.cur_epoch}",
|
||||||
client_state={},
|
client_state={},
|
||||||
exclude_frozen_parameters=True
|
exclude_frozen_parameters=True,
|
||||||
)
|
)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
convert_zero_checkpoint_to_fp32_state_dict(
|
convert_zero_checkpoint_to_fp32_state_dict(
|
||||||
@ -887,13 +819,16 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
# save sampler state dict into checkpoint
|
# save sampler state dict into checkpoint
|
||||||
sampler_state_dict = train_dl.sampler.state_dict()
|
sampler_state_dict = train_dl.sampler.state_dict()
|
||||||
torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
|
torch.save(
|
||||||
|
sampler_state_dict,
|
||||||
|
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
|
||||||
|
)
|
||||||
|
|
||||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
|
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
|
||||||
|
|
||||||
|
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
def display_and_save_batch(
|
def display_and_save_batch(
|
||||||
batch: dict,
|
batch: dict,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
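Note: this diff changes the default of --encoder-projector-ds-rate from 1 to 8, so the Whisper encoder output is downsampled 8x by the EncoderProjector before it is fed to the LLM. The projector itself is defined in model.py and is not part of this diff; the snippet below is only a minimal sketch of what an 8x frame-stacking projector of this kind could look like. The class name EncoderProjectorSketch, the two-layer MLP layout, and the example dimensions (1280 for whisper-large, 1536 for Qwen2-1.5B) are illustrative assumptions, not the repository's actual implementation.

# Illustrative sketch only -- the real EncoderProjector lives in model.py and may differ.
import torch
import torch.nn as nn


class EncoderProjectorSketch(nn.Module):
    """Stack every `ds_rate` consecutive encoder frames, then project to the LLM width."""

    def __init__(self, encoder_dim: int, llm_dim: int, ds_rate: int = 8):
        super().__init__()
        self.ds_rate = ds_rate
        self.linear1 = nn.Linear(encoder_dim * ds_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_frames, encoder_dim)
        batch, num_frames, dim = x.shape
        num_frames = num_frames - num_frames % self.ds_rate  # drop the ragged tail
        x = x[:, :num_frames, :].reshape(
            batch, num_frames // self.ds_rate, dim * self.ds_rate
        )
        return self.linear2(self.relu(self.linear1(x)))


# Example: 1500 Whisper-large frames (30 s of audio) become 187 embeddings for the LLM.
projector = EncoderProjectorSketch(encoder_dim=1280, llm_dim=1536, ds_rate=8)
print(projector(torch.randn(2, 1500, 1280)).shape)  # torch.Size([2, 187, 1536])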