mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

add readme

This commit is contained in:
parent d1e31c7ac7
commit 9ed428d7b1
@@ -1,39 +1,20 @@

# Introduction

This recipe includes scripts for training Zipformer model using multiple Chinese datasets.
This recipe includes scripts for training [Qwen-Audio](https://github.com/QwenLM/Qwen-Audio/tree/main) style model using multiple datasets.

# Included Training Sets
1. THCHS-30
2. AiShell-{1,2,4}
3. ST-CMDS
4. Primewords
5. MagicData
6. Aidatatang_200zh
7. AliMeeting
8. WeNetSpeech
9. KeSpeech-ASR
<br>
<p align="center">
<img src="assets/framework.png" width="800"/>
</p>
<br>

|Dataset| Number of hours| URL|
|---|---:|---|
|**TOTAL**|14,106|---|
|THCHS-30|35|https://www.openslr.org/18/|
|AiShell-1|170|https://www.openslr.org/33/|
|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
|AiShell-4|120|https://www.openslr.org/111/|
|ST-CMDS|110|https://www.openslr.org/38/|
|Primewords|99|https://www.openslr.org/47/|
|aidatatang_200zh|200|https://www.openslr.org/62/|
|MagicData|755|https://www.openslr.org/68/|
|AliMeeting|100|https://openslr.org/119/|
|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech|
|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech|
[./RESULTS.md](./RESULTS.md) contains the latest results.

# ASR_LLM

# Included Test Sets
1. Aishell-{1,2,4}
2. Aidatatang_200zh
3. AliMeeting
4. MagicData
5. KeSpeech-ASR
6. WeNetSpeech
The following table lists the folders for different tasks.

| | Speech Encoder | LLM | Comment |
|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
| [whisper_llm_zh](./whisper_llm_zh) | Whisper | Qwen2 | [Using multiple Chinese datasets](https://github.com/k2-fsa/icefall/tree/master/egs/multi_zh-hans/ASR) |
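The framework figure referenced above is a binary asset and does not render in this diff. As orientation, below is a minimal sketch of the Qwen-Audio-style data flow this recipe implements: frozen speech encoder, small trainable projector, (LoRA-adapted) LLM. The class and dimensions are illustrative assumptions, not the recipe's exact API (whisper-large-v2 produces 1280-dim encoder states; Qwen2-1.5B uses 1536-dim embeddings).

```python
import torch
import torch.nn as nn

# Sketch only: concatenate k consecutive encoder frames (8x time downsampling),
# then project to the LLM embedding size.
class ToyProjector(nn.Module):
    def __init__(self, encoder_dim=1280, llm_dim=1536, k=8):
        super().__init__()
        self.k = k
        self.proj = nn.Sequential(
            nn.Linear(encoder_dim * k, llm_dim), nn.ReLU(), nn.Linear(llm_dim, llm_dim)
        )

    def forward(self, x):  # x: (batch, T, encoder_dim)
        b, t, d = x.shape
        t2 = t - t % self.k  # drop trailing frames that do not fill a group of k
        x = x[:, :t2, :].reshape(b, t2 // self.k, d * self.k)
        return self.proj(x)  # (batch, T // k, llm_dim)

feats = torch.randn(2, 1496, 1280)    # encoder outputs for a 2-utterance batch
speech_embeds = ToyProjector()(feats)  # (2, 187, 1536); these embeddings replace
                                       # the <speech> placeholder token in the LLM input
```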
@@ -1,116 +1,62 @@

## Results

### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
#### Whisper
[./whisper](./whisper)
### whisper_llm_zh finetuning results

Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| Split | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
| Training Dataset | Speech Encoder | LLM | Projector | Comment | CER |
| -------------------------| ----------------|------|--------------------------------------------------|-----|--|
| Aishell1 | whisper-large-v2-aishell1-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.76% |
<!-- | Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| WIP ||
| Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze| Qwen2-7B-Instruct, LoRA | Linear, 8x downsample| WIP || -->

Command for training is:
```bash
pip install -r whisper/requirements.txt
pip install -r whisper_llm_zh/requirements.txt

# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen

torchrun --nproc-per-node 8 ./whisper/train.py \
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper/ds_config_zero1.json
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True
```

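The `--use-lora True` flag enables parameter-efficient fine-tuning of the LLM. The adapter configuration this recipe applies (it appears verbatim in the `train.py` diff further down this page) is roughly:

```python
from peft import LoraConfig, get_peft_model

# Mirrors the LoraConfig used by whisper_llm_zh/train.py in this commit:
# rank-64 adapters on every attention and MLP projection of Qwen2.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "up_proj", "gate_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)
# llm = get_peft_model(llm, lora_config)
# Only the adapters (plus the encoder projector) receive gradients; the
# Whisper encoder and base LLM weights stay frozen unless --unfreeze-llm is set.
```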
Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --beam-size 10 --max-duration 50
  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
```

Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>


### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.

#### Non-streaming (with CTC head)

Best results (num of params: ~69M):

The training command:

```
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 20 \
  --use-fp16 1 \
  --max-duration 600 \
  --num-workers 8 \
  --use-ctc 1
```

The decoding command:

```
./zipformer/decode.py \
  --epoch 20 \
  --avg 1 \
  --use-ctc 1
```

Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using BPE model (# tokens is 2000, byte fallback enabled).

| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 |
| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 |

Pre-trained model can be found here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/

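The 2000-token byte-fallback BPE vocabulary mentioned above can be reproduced with SentencePiece; a minimal sketch with hypothetical paths (the recipe's actual training-text file name is not shown on this page):

```bash
# Byte fallback keeps characters unseen at training time representable
# as byte pieces instead of collapsing them to <unk>.
spm_train \
  --input=data/transcript_chars.txt \
  --model_prefix=data/lang_bpe_2000/bpe \
  --vocab_size=2000 \
  --model_type=bpe \
  --byte_fallback=true \
  --character_coverage=0.98
```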
#### Non-streaming

Best results (num of params: ~69M):

The training command:

```
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 20 \
  --use-fp16 1 \
  --max-duration 600 \
  --num-workers 8
```

The decoding command:

```
./zipformer/decode.py \
  --epoch 20 \
  --avg 1
```

Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model (# tokens is 2000, byte fallback enabled).

| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |


Pre-trained model can be found here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
BIN egs/speech_llm/ASR_LLM/assets/framework.png (new file, 834 KiB; binary file not shown)

egs/speech_llm/ASR_LLM/prepare.sh (new file, 46 lines)
@@ -0,0 +1,46 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

stage=0
stop_stage=0
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}


if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "stage 0: Download whisper-large-v2 aishell 1 fbank feature from huggingface"

  # pip install huggingface_hub['cli']
  # for aishell 1
  huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse

fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"

  # for multi-hans-zh
  huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
  huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
  huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "stage 2: Download whisper-large-v2 speechio test sets fbank feature from huggingface"

  # for speechio test sets
  mkdir -p data_speechio
  huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio
  mv data_speechio/fbank/* data/fbank
fi
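Note that the script hardcodes `stage=0` and `stop_stage=0` and does not source icefall's usual option parser, so the stage window has to be changed in the file itself. A hedged example of running stages 0 through 2 (editing the variable by hand works equally well):

```bash
sed -i 's/^stop_stage=0$/stop_stage=2/' prepare.sh
bash prepare.sh
```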
egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py

@@ -20,21 +20,34 @@
"""
Usage:
# Command for decoding using fine-tuned models:
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --beam-size 10 --max-duration 50

  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
"""

import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -42,18 +55,17 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
# from tn.chinese.normalizer import Normalizer
# from whisper.normalizers import BasicTextNormalizer
# from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
# from zhconv import convert
import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from model import EncoderProjector, SPEECH_LLM
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@@ -63,8 +75,7 @@ from icefall.utils import (
    str2bool,
    write_error_stats,
)
from train import DEFAULT_SPEECH_TOKEN
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -117,6 +128,7 @@ def average_checkpoints(

    return avg


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--llm-path-or-name",
@@ -135,7 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder-projector-ds-rate",
        type=int,
        default=1,
        default=8,
        help="Downsample rate for the encoder projector.",
    )

@@ -149,10 +161,11 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--use-lora",
        type=str2bool,
        default=False,
        help="Whether to use lora to fine-tune llm.",
        default=True,
        help="Whether to use lora fine-tuned llm checkpoint.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -247,6 +260,7 @@ def decode_one_batch(
    Returns:
        Return a dict, whose key may be "beam-search".
    """

    def preprocess(
        messages,
        tokenizer: transformers.PreTrainedTokenizer,
@@ -268,10 +282,16 @@ def decode_one_batch(
            )
        )
        max_len_texts = max([len(text) for text in texts])
        if tokenizer.padding_side == 'right':
            texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
        if tokenizer.padding_side == "right":
            texts = [
                text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
                for text in texts
            ]
        else:
            texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
            texts = [
                [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
                for text in texts
            ]

        input_ids = torch.tensor(texts, dtype=torch.int)

@@ -302,16 +322,18 @@ def decode_one_batch(
    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)

    messages = [[
        {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
        {"role": "assistant", "content": ""},
    ]] * len(feature)
    messages = [
        [
            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            {"role": "assistant", "content": ""},
        ]
    ] * len(feature)

    input_ids, attention_mask = preprocess(
        messages, tokenizer, max_len=128
    input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)

    generated_ids = model.decode(
        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
    )

    generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    return {"beam-search": hyps}
@@ -497,14 +519,14 @@ def main():

    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16
        torch_dtype=torch.float16
        tokenizer.padding_side = 'left'
        # torch_dtype=torch.bfloat16 FIX ME
        torch_dtype = torch.float16
        tokenizer.padding_side = "left"

    else:
        attn_implementation = "eager"
        torch_dtype=torch.float16
        tokenizer.padding_side = 'right'
        torch_dtype = torch.float16
        tokenizer.padding_side = "right"

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
@@ -515,23 +537,33 @@ def main():
        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "up_proj",
                "gate_proj",
                "down_proj",
            ],
            task_type="CAUSAL_LM",
        )
        llm = get_peft_model(llm, lora_config)
        llm.print_trainable_parameters()

    special_tokens_dict = {
        "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
    }
    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
    )

    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    model = SPEECH_LLM(
        speech_encoder,
@@ -539,7 +571,6 @@ def main():
        encoder_projector,
    )


    if params.avg > 1:
        start = params.epoch - params.avg + 1
        assert start >= 1, start
@@ -579,7 +610,7 @@ def main():
    #
    if c.duration > 30.0:
        logging.warning(
            f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
        )
        return False
    return True
@@ -1 +0,0 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py
egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py

@@ -1,11 +1,20 @@
from torch import nn
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
from icefall.dist import get_rank

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class EncoderProjector(nn.Module):
    # https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
    """
    The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
    """

    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
        super().__init__()
        self.downsample_rate = downsample_rate
@@ -22,14 +31,28 @@ class EncoderProjector(nn.Module):
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate)
        x = x.view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )

        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

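# Illustrative shape walk-through (assumed numbers: whisper-large-v2
# encoder_dim=1280, Qwen2-1.5B llm_dim=1536, downsample_rate=8):
#   x: (B, T, 1280) --view--> (B, T // 8, 10240) --linear1/relu/linear2--> (B, T // 8, 1536)
# i.e. the projector trades 8x time resolution for LLM-sized frame embeddings.
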
class SPEECH_LLM(nn.Module):
    """
    The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
    The encoder is used to extract speech features from the input speech signal.
    The encoder projector is used to project the encoder outputs to the same dimension as the language model.
    The language model is used to generate the text from the speech features.
    Args:
        encoder (:obj:`nn.Module`): The encoder module.
        llm (:obj:`nn.Module`): The language model module.
        encoder_projector (:obj:`nn.Module`): The encoder projector module.
    """

    def __init__(
        self,
        encoder: nn.Module,
@@ -41,23 +64,46 @@ class SPEECH_LLM(nn.Module):
        self.llm = llm
        self.encoder_projector = encoder_projector

    def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
    ):
        """
        Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
        with the speech features and padding the input_ids to the maximum length of the speech features.
        Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
        Args:
            speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
            inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
            input_ids (:obj:`torch.Tensor`): The input ids to merge.
            attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
            labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
        Returns:
            :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
        """
        num_speechs, speech_len, embed_dim = speech_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
        )
        # 1. Create a mask to know where special speech tokens are
        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
        num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
        batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)
        max_embed_dim = (
            num_special_speech_tokens.max() * (speech_len - 1)
        ) + sequence_length
        batch_indices, non_speech_indices = torch.where(
            input_ids != self.llm.config.default_speech_token_id
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged speech-text sequence.
        # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
        # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        new_token_positions = (
            torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        )
        nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_speech_pad[:, None]  # offset for left padding
@@ -65,14 +111,24 @@ class SPEECH_LLM(nn.Module):

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
                (batch_size, max_embed_dim),
                IGNORE_TOKEN_ID,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
@@ -86,17 +142,28 @@ class SPEECH_LLM(nn.Module):

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_speech_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_speech_indices]
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_speech_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_speech_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_speech_indices]
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_speech_indices
            ]

        # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
        speech_to_overwrite = torch.full(
            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
            (batch_size, max_embed_dim),
            True,
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        speech_to_overwrite[batch_indices, text_to_overwrite] = False
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[:, None].to(target_device)
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
            :, None
        ].to(target_device)

        if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
            raise ValueError(
@@ -104,12 +171,18 @@ class SPEECH_LLM(nn.Module):
                f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[speech_to_overwrite] = speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_embedding[speech_to_overwrite] = (
            speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        final_attention_mask |= speech_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1
        )

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == self.llm.config.pad_token_id)
        batch_indices, pad_indices = torch.where(
            input_ids == self.llm.config.pad_token_id
        )
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0
@@ -119,62 +192,59 @@ class SPEECH_LLM(nn.Module):

        return final_embedding, final_attention_mask, final_labels, position_ids

    def forward(self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.LongTensor = None,
    ):
    def forward(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.LongTensor = None,
    ):
        encoder_outs = self.encoder(fbank)

        speech_features = self.encoder_projector(encoder_outs)

        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        enable_logging = False
        rank = get_rank()

        # log only on rank 0, training using deep
        if enable_logging and rank == 0:
            print("input_ids", input_ids, input_ids.shape)
            print("labels", labels, labels.shape)
            print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
            print("attention_mask_before", attention_mask.shape, attention_mask)
            print(2333333333333333333333333333)
        inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )
        if enable_logging and rank == 0:
            print("speech_features", speech_features.shape, speech_features)
            print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
            print("attention_mask", attention_mask.shape, attention_mask)
            print("position_ids", position_ids.shape, position_ids)
            print("labels", labels, labels.shape)
            print("================================================================")

        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
        # model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
        )

        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            if enable_logging and rank == 0:
                print("preds", preds, preds.shape)
                print(4555555555555555555555555555555555555555555)
            acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
        return model_outputs, acc

    def decode(self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        **kwargs
    ):
    def decode(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):

        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(torch.float16)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
        (
            inputs_embeds,
            attention_mask,
            _,
            position_ids,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask
        )
        generated_ids = self.llm.generate(
@@ -189,7 +259,7 @@ class SPEECH_LLM(nn.Module):
            temperature=kwargs.get("temperature", 1.0),
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id
            pad_token_id=self.llm.config.pad_token_id,
        )

        return generated_ids
@@ -197,7 +267,7 @@ class SPEECH_LLM(nn.Module):

def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
@@ -212,4 +282,4 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
    return numerator.float() / denominator.float()
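# Illustrative use of compute_accuracy (hypothetical tensors; IGNORE_TOKEN_ID
# is -100, i.e. LabelSmoother.ignore_index):
#   preds   = torch.tensor([[5, 7, 9]])
#   targets = torch.tensor([[5, -100, 9]])  # middle position is masked out
#   compute_accuracy(preds, targets, ignore_label=-100)  # -> tensor(1.), 2/2 unmasked positions match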
egs/speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py

@@ -248,8 +248,6 @@ class MultiDataset:

    def aishell_train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # AISHELL-1
        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
@@ -257,11 +255,8 @@ class MultiDataset:

        return aishell_cuts


    def aishell_dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")

        # AISHELL
        logging.info("Loading Aishell set in lazy mode")
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
@@ -271,8 +266,6 @@ class MultiDataset:

    def aishell_test_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")

        # AISHELL
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
@@ -282,12 +275,8 @@ class MultiDataset:
            "aishell_test": aishell_test_cuts,
        }


    # aishell 2
    def aishell2_train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # AISHELL-2
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
@@ -297,8 +286,6 @@ class MultiDataset:

    def aishell2_dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")

        # AISHELL-2
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
@@ -308,8 +295,6 @@ class MultiDataset:

    def aishell2_test_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")

        # AISHELL-2
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
@@ -321,8 +306,6 @@ class MultiDataset:

    def wenetspeech_test_meeting_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")

        # WeNetSpeech
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt (executable file → normal file)

@@ -5,9 +5,6 @@ sentencepiece
pypinyin
tensorboard
librosa
# git+https://github.com/yuekaizhang/whisper.git
# zhconv
# WeTextProcessing
deepspeed
transformers>=4.37.0
flash-attn
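One practical note on the `flash-attn` requirement: it compiles CUDA extensions during installation, and the upstream flash-attn README recommends installing it with build isolation disabled:

```bash
pip install flash-attn --no-build-isolation
```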
egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py

@@ -17,14 +17,28 @@
# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen

# fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper/ds_config_zero1.json
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True
"""

import argparse
@@ -39,36 +53,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union

import deepspeed
import k2
# import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
# from optim import Eden, ScaledAdam
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
# from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
@@ -77,20 +84,15 @@ from icefall.utils import (
    str2bool,
)

from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
from transformers.trainer_pt_utils import LabelSmoother

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"


def set_batch_count(model: nn.Module, batch_count: float) -> None:
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--llm-path-or-name",
@@ -109,7 +111,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder-projector-ds-rate",
        type=int,
        default=1,
        default=8,
        help="Downsample rate for the encoder projector.",
    )
    parser.add_argument(
@@ -133,6 +135,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Whether to unfreeze llm during training.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -162,15 +165,6 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--start-batch",
        type=int,
        default=0,
        help="""If positive, --start-epoch is ignored and
        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
@@ -198,26 +192,6 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--base-lr", type=float, default=1e-5, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=5000,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=6,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--seed",
        type=int,
@@ -225,44 +199,6 @@ def get_parser():
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--keep-last-k",
        type=int,
        default=30,
        help="""Only keep this number of checkpoints on disk.
        For instance, if it is 3, there are only 3 checkpoints
        in the exp-dir with filenames `checkpoint-xxx.pt`.
        It does not affect checkpoints with name `epoch-xxx.pt`.
        """,
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
@@ -325,6 +261,7 @@ def get_params() -> AttributeDict:

    return params


def compute_loss(
    params: AttributeDict,
    tokenizer: AutoTokenizer,
@@ -372,17 +309,23 @@ def compute_loss(
                tokenize=True,
                chat_template=TEMPLATE,
                add_generation_prompt=False,
                padding="longest", # FIX me change padding to longest
                padding="longest",  # FIX me change padding to longest
                max_length=max_len,
                truncation=True,
            )
        )
    # padding texts to the same length, texts is a list of list, padding with tokenizer.pad_token_id
    max_len_texts = max([len(text) for text in texts])
    if tokenizer.padding_side == 'right':
        texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
    if tokenizer.padding_side == "right":
        texts = [
            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
            for text in texts
        ]
    else:
        texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
        texts = [
            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
            for text in texts
        ]
    input_ids = torch.tensor(texts, dtype=torch.int)
    # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
    target_ids = input_ids.clone()
@@ -391,13 +334,14 @@ def compute_loss(
    # first get the indices of the tokens
    mask_prompt = True
    if mask_prompt:
        mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
        # then mask all tokens before the first token e.g. 151646 (speech), 151645 <assistant>, 198 \n
        mask_indices = torch.where(
            input_ids == tokenizer.convert_tokens_to_ids("assistant")
        )
        for i in range(mask_indices[0].size(0)):
            row = mask_indices[0][i]
            col = mask_indices[1][i]
            # + 2 to skip: 'assistant', '\n'
            target_ids[row, :col+2] = IGNORE_TOKEN_ID
            target_ids[row, : col + 2] = IGNORE_TOKEN_ID

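    # Illustrative effect of the masking above (token names, not real ids): for a
    # row tokenized as [<speech>, ..., "assistant", "\n", reply tokens..., <|im_end|>],
    # every position up to and including the "\n" after "assistant" becomes
    # IGNORE_TOKEN_ID, so the loss is computed only on the assistant's reply.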
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

@@ -458,20 +402,13 @@ def compute_loss(

    messages = []
    for i, text in enumerate(texts):
        # message = [
        #     {"role": "system", "content": "你是一个能处理音频的助手。"},
        #     {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
        #     {"role": "assistant", "content": text},
        # ]
        message = [
            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            {"role": "assistant", "content": text},
            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            {"role": "assistant", "content": text},
        ]
        messages.append(message)

    input_ids, attention_mask, target_ids = preprocess(
        messages, tokenizer, max_len=128
    )
    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)

    target_ids = target_ids.type(torch.LongTensor)
    input_ids = input_ids.type(torch.LongTensor)
@@ -494,7 +431,9 @@ def compute_loss(

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
    info["acc"] = acc * info["frames"]  # WAR: to avoid normalization by the number of frames
    info["acc"] = (
        acc * info["frames"]
    )  # WAR: to avoid normalization by the number of frames

    return loss, info

@@ -607,7 +546,7 @@ def train_one_epoch(
                save_dir=params.exp_dir,
                tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
                client_state={},
                exclude_frozen_parameters=True
                exclude_frozen_parameters=True,
            )

            if rank == 0:
@@ -703,10 +642,7 @@ def run(rank, world_size, args):

    logging.info("About to create model")

    # if 'whisper' in params.speech_encoder_path_or_name:
    replace_whisper_encoder_forward()
    # TODO: directly loading from whisper-ft checkpoint
    # whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
    speech_encoder = whisper_model.encoder
    speech_encoder_dim = whisper_model.dims.n_audio_state
@@ -717,14 +653,14 @@ def run(rank, world_size, args):
    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16
        torch_dtype=torch.float16
        tokenizer.padding_side = 'left'
        # torch_dtype=torch.bfloat16 FIX ME
        torch_dtype = torch.float16
        tokenizer.padding_side = "left"

    else:
        attn_implementation = "eager"
        torch_dtype=torch.float16
        tokenizer.padding_side = 'right'
        torch_dtype = torch.float16
        tokenizer.padding_side = "right"

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
@@ -741,21 +677,31 @@ def run(rank, world_size, args):
        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "up_proj",
                "gate_proj",
                "down_proj",
            ],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )
        llm = get_peft_model(llm, lora_config)
        llm.print_trainable_parameters()

    special_tokens_dict = {
        "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
    }
    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.pad_token_id
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
    )

    encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    model = SPEECH_LLM(
        speech_encoder,
@@ -814,12 +760,6 @@ def run(rank, world_size, args):

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
    #     # We only load the sampler's state dict when it loads a checkpoint
    #     # saved in the middle of an epoch
    #     sampler_state_dict = checkpoints["sampler"]
    # else:
    #     sampler_state_dict = None
    sampler_state_dict = None
    if params.sampler_state_dict_path:
        sampler_state_dict = torch.load(params.sampler_state_dict_path)
@@ -840,13 +780,6 @@ def run(rank, world_size, args):
    else:
        tb_writer = None

    # if params.pretrained_model_path:
    #     checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
    #     if "model" not in checkpoint:
    #         model.load_state_dict(checkpoint, strict=True)
    #     else:
    #         load_checkpoint(params.pretrained_model_path, model)

    logging.info(f"start training from epoch {params.start_epoch}")
    for epoch in range(params.start_epoch, params.num_epochs + 1):

@@ -871,12 +804,11 @@ def run(rank, world_size, args):
        rank=rank,
    )


    model.save_checkpoint(
        save_dir=params.exp_dir,
        tag=f"epoch-{params.cur_epoch}",
        client_state={},
        exclude_frozen_parameters=True
        exclude_frozen_parameters=True,
    )
    if rank == 0:
        convert_zero_checkpoint_to_fp32_state_dict(
@@ -887,13 +819,16 @@ def run(rank, world_size, args):
        )
        # save sampler state dict into checkpoint
        sampler_state_dict = train_dl.sampler.state_dict()
        torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
        torch.save(
            sampler_state_dict,
            f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
        )

        os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")


    logging.info("Done!")


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,