add readme

This commit is contained in:
root 2024-06-14 03:55:57 +00:00 committed by Yuekai Zhang
parent d1e31c7ac7
commit 9ed428d7b1
10 changed files with 397 additions and 409 deletions

View File

@@ -1,39 +1,20 @@
# Introduction

This recipe includes scripts for training a [Qwen-Audio](https://github.com/QwenLM/Qwen-Audio/tree/main) style model using multiple datasets.

<br>
<p align="center">
  <img src="assets/framework.png" width="800"/>
</p>
<br>

[./RESULTS.md](./RESULTS.md) contains the latest results.

# ASR_LLM

The following table lists the folders for different tasks.

|                                    | Speech Encoder | LLM   | Comment                                                                                                 |
|------------------------------------|----------------|-------|---------------------------------------------------------------------------------------------------------|
| [whisper_llm_zh](./whisper_llm_zh) | Whisper        | Qwen2 | [Using multiple Chinese datasets](https://github.com/k2-fsa/icefall/tree/master/egs/multi_zh-hans/ASR) |
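For orientation, a rough workflow sketch follows (the `egs/speech_llm/ASR_LLM` path comes from the file listing in this commit, and `prepare.sh` is assumed to sit at the recipe root); the full training and decoding commands live in [./RESULTS.md](./RESULTS.md).

```bash
# Rough workflow sketch; see RESULTS.md for the exact training/decoding flags.
cd egs/speech_llm/ASR_LLM
pip install -r whisper_llm_zh/requirements.txt
bash prepare.sh   # downloads pre-computed whisper fbank features into ./data
```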

View File

@@ -1,116 +1,62 @@
## Results

### whisper_llm_zh finetuning results

| Training Dataset | Speech Encoder                       | LLM                       | Projector             | Comment                                                                                                                | CER                 |
|------------------|--------------------------------------|---------------------------|-----------------------|--------------------------------------------------------------------------------------------------------------------------|---------------------|
| Aishell1         | whisper-large-v2-aishell1-ft, freeze | Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample | [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.76% |
<!-- | Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze | Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample | WIP ||
| Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze | Qwen2-7B-Instruct, LoRA | Linear, 8x downsample | WIP || -->

Command for training is:
```bash
pip install -r whisper_llm_zh/requirements.txt
pip install huggingface_hub[cli]

mkdir -p models/whisper models/qwen

# For the aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True
```

Command for decoding using fine-tuned models:
```bash
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

# For the aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
```

Binary file not shown (new image, 834 KiB).

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=0
stop_stage=0
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download whisper-large-v2 aishell 1 fbank feature from huggingface"
# pip install huggingface_hub['cli']
# for aishell 1
huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
# for multi-hans-zh
huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: Download whisper-large-v2 speechio test sets fbank feature from huggingface"
# for speechio test sets
mkdir -p data_speechio
huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio
mv data_speechio/fbank/* data/fbank
fi
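After a download stage finishes, a quick sanity check (not part of the recipe) is to open one of the manifests with lhotse; the path below assumes the files land under `data/fbank`, which is what `multi_dataset.py` and the `--manifest-dir data/fbank` flag expect:

```python
from lhotse import load_manifest_lazy

# Load one downloaded manifest lazily and look at the first cut.
cuts = load_manifest_lazy("data/fbank/aishell_cuts_dev.jsonl.gz")
cut = next(iter(cuts))
print(cut.duration, cut.supervisions[0].text)
```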

View File

@@ -20,21 +20,34 @@
"""
Usage:
# Command for decoding using fine-tuned models:

pip install huggingface_hub[cli]
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

# For the aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt

python3 ./whisper_llm_zh/decode.py \
  --max-duration 80 \
  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name models/qwen \
  --epoch 999 --avg 1 \
  --manifest-dir data/fbank \
  --use-flash-attn True \
  --use-lora True --dataset aishell
"""

import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

@@ -42,18 +55,17 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@@ -63,8 +75,7 @@ from icefall.utils import (
    str2bool,
    write_error_stats,
)


def average_checkpoints(
    filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -117,6 +128,7 @@ def average_checkpoints(
    return avg


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--llm-path-or-name",
@@ -135,7 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder-projector-ds-rate",
        type=int,
        default=8,
        help="Downsample rate for the encoder projector.",
    )
@@ -149,10 +161,11 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--use-lora",
        type=str2bool,
        default=True,
        help="Whether to use lora fine-tuned llm checkpoint.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -247,6 +260,7 @@ def decode_one_batch(
    Returns:
      Return a dict, whose key may be "beam-search".
    """

    def preprocess(
        messages,
        tokenizer: transformers.PreTrainedTokenizer,
@@ -268,10 +282,16 @@ def decode_one_batch(
            )
        )
        max_len_texts = max([len(text) for text in texts])
        if tokenizer.padding_side == "right":
            texts = [
                text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
                for text in texts
            ]
        else:
            texts = [
                [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
                for text in texts
            ]

        input_ids = torch.tensor(texts, dtype=torch.int)
@@ -302,16 +322,18 @@ def decode_one_batch(
    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)

    messages = [
        [
            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            {"role": "assistant", "content": ""},
        ]
    ] * len(feature)

    input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)

    generated_ids = model.decode(
        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
    )
    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    return {"beam-search": hyps}
@@ -497,14 +519,14 @@ def main():
    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16 FIX ME
        torch_dtype = torch.float16
        tokenizer.padding_side = "left"

    else:
        attn_implementation = "eager"
        torch_dtype = torch.float16
        tokenizer.padding_side = "right"

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
@@ -515,23 +537,33 @@ def main():
        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "up_proj",
                "gate_proj",
                "down_proj",
            ],
            task_type="CAUSAL_LM",
        )
        llm = get_peft_model(llm, lora_config)
        llm.print_trainable_parameters()

    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
    )

    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    model = SPEECH_LLM(
        speech_encoder,
@@ -539,7 +571,6 @@ def main():
        encoder_projector,
    )

    if params.avg > 1:
        start = params.epoch - params.avg + 1
        assert start >= 1, start
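For reference, here is a hedged sketch (not part of decode.py) of the prompt that the message list used above produces once Qwen2's chat template is applied; `<speech>` is the placeholder token that `_merge_input_ids_with_speech_features` in model.py later replaces with the projected whisper frames:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
tokenizer.add_special_tokens({"additional_special_tokens": ["<speech>"]})

messages = [
    {"role": "user", "content": "<speech>请转写音频为文字"},
    {"role": "assistant", "content": ""},
]
# Render the prompt as plain text to inspect what the LLM is actually fed.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```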

View File

@@ -1 +0,0 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@@ -1,11 +1,20 @@
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class EncoderProjector(nn.Module):
    """
    The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
    Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
    Args:
        encoder_dim (:obj:`int`): The dimension of the encoder outputs.
        llm_dim (:obj:`int`): The dimension of the language model.
        downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
    """

    def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
        super().__init__()
        self.downsample_rate = downsample_rate
@@ -22,14 +31,28 @@ class EncoderProjector(nn.Module):
        seq_len = x.size(1)

        x = x.contiguous()
        x = x.view(
            batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
        )
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


class SPEECH_LLM(nn.Module):
    """
    The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
    The encoder is used to extract speech features from the input speech signal.
    The encoder projector is used to project the encoder outputs to the same dimension as the language model.
    The language model is used to generate the text from the speech features.
    Args:
        encoder (:obj:`nn.Module`): The encoder module.
        llm (:obj:`nn.Module`): The language model module.
        encoder_projector (:obj:`nn.Module`): The encoder projector module.
    """

    def __init__(
        self,
        encoder: nn.Module,
@@ -41,23 +64,46 @@ class SPEECH_LLM(nn.Module):
        self.llm = llm
        self.encoder_projector = encoder_projector

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
    ):
        """
        Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
        with the speech features and padding the input_ids to the maximum length of the speech features.
        Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
        Args:
            speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
            inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
            input_ids (:obj:`torch.Tensor`): The input ids to merge.
            attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
            labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
        Returns:
            :obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
        """
        num_speechs, speech_len, embed_dim = speech_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
        )
        # 1. Create a mask to know where special speech tokens are
        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
        num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (
            num_special_speech_tokens.max() * (speech_len - 1)
        ) + sequence_length
        batch_indices, non_speech_indices = torch.where(
            input_ids != self.llm.config.default_speech_token_id
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged speech-text sequence.
        # `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
        # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = (
            torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        )
        nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_speech_pad[:, None]  # offset for left padding
@@ -65,14 +111,24 @@ class SPEECH_LLM(nn.Module):
        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                IGNORE_TOKEN_ID,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
@@ -86,17 +142,28 @@ class SPEECH_LLM(nn.Module):
        # 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_speech_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_speech_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_speech_indices
            ]

        # 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
        speech_to_overwrite = torch.full(
            (batch_size, max_embed_dim),
            True,
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        speech_to_overwrite[batch_indices, text_to_overwrite] = False
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
            :, None
        ].to(target_device)

        if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
            raise ValueError(
@@ -104,12 +171,18 @@ class SPEECH_LLM(nn.Module):
                f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[speech_to_overwrite] = (
            speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        final_attention_mask |= speech_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1
        )

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(
            input_ids == self.llm.config.pad_token_id
        )
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0
@@ -119,7 +192,8 @@ class SPEECH_LLM(nn.Module):
        return final_embedding, final_attention_mask, final_labels, position_ids

    def forward(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
@@ -131,50 +205,46 @@ class SPEECH_LLM(nn.Module):
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        (
            inputs_embeds,
            attention_mask,
            labels,
            _,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask, labels
        )

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
        )

        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc = compute_accuracy(
                preds.detach()[:, :-1],
                labels.detach()[:, 1:],
                ignore_label=IGNORE_TOKEN_ID,
            )
        return model_outputs, acc

    def decode(
        self,
        fbank: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):

        encoder_outs = self.encoder(fbank)
        speech_features = self.encoder_projector(encoder_outs)
        speech_features = speech_features.to(torch.float16)
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        (
            inputs_embeds,
            attention_mask,
            _,
            position_ids,
        ) = self._merge_input_ids_with_speech_features(
            speech_features, inputs_embeds, input_ids, attention_mask
        )
        generated_ids = self.llm.generate(
@@ -189,7 +259,7 @@ class SPEECH_LLM(nn.Module):
            temperature=kwargs.get("temperature", 1.0),
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id,
        )

        return generated_ids
@@ -197,7 +267,7 @@ class SPEECH_LLM(nn.Module):
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.
    Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
@@ -212,4 +282,4 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return numerator.float() / denominator.float()
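A minimal shape check for the projector (illustrative only; it assumes the model.py above is importable and uses whisper-large-v2's 1280-dim encoder output together with Qwen2-1.5B-Instruct's 1536-dim hidden size):

```python
import torch

from model import EncoderProjector

# 8x downsampling: (batch, frames, encoder_dim) -> (batch, frames // 8, llm_dim)
proj = EncoderProjector(encoder_dim=1280, llm_dim=1536, downsample_rate=8)
x = torch.randn(2, 1496, 1280)  # frame count chosen as a multiple of 8
y = proj(x)
print(y.shape)  # torch.Size([2, 187, 1536])
```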

View File

@@ -248,8 +248,6 @@ class MultiDataset:
    def aishell_train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
@@ -257,11 +255,8 @@ class MultiDataset:
        return aishell_cuts

    def aishell_dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")
        logging.info("Loading Aishell set in lazy mode")
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
@@ -271,8 +266,6 @@ class MultiDataset:
    def aishell_test_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
@@ -282,12 +275,8 @@ class MultiDataset:
            "aishell_test": aishell_test_cuts,
        }

    def aishell2_train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
@@ -297,8 +286,6 @@ class MultiDataset:
    def aishell2_dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
@@ -308,8 +295,6 @@ class MultiDataset:
    def aishell2_test_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
@@ -321,8 +306,6 @@ class MultiDataset:
    def wenetspeech_test_meeting_cuts(self) -> CutSet:
        logging.info("About to get multidataset test cuts")
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
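A hedged usage sketch for this class (it assumes the constructor takes the fbank directory, which is how `--manifest-dir data/fbank` is wired into train.py and decode.py):

```python
from pathlib import Path

from multi_dataset import MultiDataset

multi_dataset = MultiDataset(Path("data/fbank"))
test_sets = multi_dataset.aishell_test_cuts()  # e.g. {"aishell_test": <CutSet>}
for name, cuts in test_sets.items():
    print(name, cuts)
```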

egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt Executable file → Normal file
View File

@@ -5,9 +5,6 @@ sentencepiece
pypinyin
tensorboard
librosa
deepspeed
transformers>=4.37.0
flash-attn
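One practical note that is an assumption about typical environments rather than something this commit states: flash-attn compiles against the already-installed PyTorch, so if installing the requirements fails on it, installing it separately (with a build matching your CUDA and PyTorch versions; see the flash-attn project docs) is a common workaround:

```bash
pip install torch                           # flash-attn needs torch available at build time
pip install flash-attn --no-build-isolation
```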

View File

@@ -17,14 +17,28 @@
# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2

pip install huggingface_hub[cli]
mkdir -p models/whisper models/qwen

# For the aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For the multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct

torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
  --max-duration 200 \
  --exp-dir ./whisper_llm_zh/exp_test \
  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
  --use-flash-attn True \
  --use-lora True --unfreeze-llm True
"""

import argparse
@@ -39,36 +53,29 @@ from typing import Any, Dict, List, Optional, Tuple
import deepspeed
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
@@ -77,20 +84,15 @@ from icefall.utils import (
    str2bool,
)

DEFAULT_SPEECH_TOKEN = "<speech>"


def set_batch_count(model: nn.Module, batch_count: float) -> None:
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--llm-path-or-name",
@@ -109,7 +111,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder-projector-ds-rate",
        type=int,
        default=8,
        help="Downsample rate for the encoder projector.",
    )
    parser.add_argument(
@@ -133,6 +135,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Whether to unfreeze llm during training.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -162,15 +165,6 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
@@ -198,26 +192,6 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--seed",
        type=int,
@@ -225,44 +199,6 @@ def get_parser():
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
@@ -325,6 +261,7 @@ def get_params() -> AttributeDict:
    return params


def compute_loss(
    params: AttributeDict,
    tokenizer: AutoTokenizer,
@@ -379,10 +316,16 @@ def compute_loss(
            )
        # padding texts to the same length, texts is a list of list, padding with tokenizer.pad_token_id
        max_len_texts = max([len(text) for text in texts])
        if tokenizer.padding_side == "right":
            texts = [
                text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
                for text in texts
            ]
        else:
            texts = [
                [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
                for text in texts
            ]

        input_ids = torch.tensor(texts, dtype=torch.int)
        # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
        target_ids = input_ids.clone()
@@ -391,8 +334,9 @@ def compute_loss(
        # first get the indices of the tokens
        mask_prompt = True
        if mask_prompt:
            mask_indices = torch.where(
                input_ids == tokenizer.convert_tokens_to_ids("assistant")
            )
            for i in range(mask_indices[0].size(0)):
                row = mask_indices[0][i]
                col = mask_indices[1][i]
@@ -458,20 +402,13 @@ def compute_loss(
    messages = []
    for i, text in enumerate(texts):
        message = [
            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            {"role": "assistant", "content": text},
        ]
        messages.append(message)

    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)

    target_ids = target_ids.type(torch.LongTensor)
    input_ids = input_ids.type(torch.LongTensor)
@@ -494,7 +431,9 @@ def compute_loss(
    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
    info["acc"] = (
        acc * info["frames"]
    )  # WAR: to avoid normalization by the number of frames

    return loss, info
@@ -607,7 +546,7 @@ def train_one_epoch(
                save_dir=params.exp_dir,
                tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
                client_state={},
                exclude_frozen_parameters=True,
            )

            if rank == 0:
@@ -703,10 +642,7 @@ def run(rank, world_size, args):
    logging.info("About to create model")

    replace_whisper_encoder_forward()
    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
    speech_encoder = whisper_model.encoder
    speech_encoder_dim = whisper_model.dims.n_audio_state
@@ -717,14 +653,14 @@ def run(rank, world_size, args):
    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
    if params.use_flash_attn:
        attn_implementation = "flash_attention_2"
        # torch_dtype=torch.bfloat16 FIX ME
        torch_dtype = torch.float16
        tokenizer.padding_side = "left"

    else:
        attn_implementation = "eager"
        torch_dtype = torch.float16
        tokenizer.padding_side = "right"

    llm = AutoModelForCausalLM.from_pretrained(
        params.llm_path_or_name,
@@ -741,21 +677,31 @@ def run(rank, world_size, args):
        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "up_proj",
                "gate_proj",
                "down_proj",
            ],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )
        llm = get_peft_model(llm, lora_config)
        llm.print_trainable_parameters()

    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
    tokenizer.add_special_tokens(special_tokens_dict)
    llm.config.pad_token_id = tokenizer.pad_token_id
    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
        DEFAULT_SPEECH_TOKEN
    )

    encoder_projector = EncoderProjector(
        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )

    model = SPEECH_LLM(
        speech_encoder,
@@ -814,12 +760,6 @@ def run(rank, world_size, args):
    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    sampler_state_dict = None
    if params.sampler_state_dict_path:
        sampler_state_dict = torch.load(params.sampler_state_dict_path)
@@ -840,13 +780,6 @@ def run(rank, world_size, args):
    else:
        tb_writer = None

    logging.info(f"start training from epoch {params.start_epoch}")

    for epoch in range(params.start_epoch, params.num_epochs + 1):
@@ -871,12 +804,11 @@ def run(rank, world_size, args):
            rank=rank,
        )

        model.save_checkpoint(
            save_dir=params.exp_dir,
            tag=f"epoch-{params.cur_epoch}",
            client_state={},
            exclude_frozen_parameters=True,
        )
        if rank == 0:
            convert_zero_checkpoint_to_fp32_state_dict(
@@ -887,13 +819,16 @@ def run(rank, world_size, args):
            )
            # save sampler state dict into checkpoint
            sampler_state_dict = train_dl.sampler.state_dict()
            torch.save(
                sampler_state_dict,
                f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
            )

            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")

    logging.info("Done!")


def display_and_save_batch(
    batch: dict,
    params: AttributeDict,