add readme

root 2024-06-14 03:55:57 +00:00 committed by Yuekai Zhang
parent d1e31c7ac7
commit 9ed428d7b1
10 changed files with 397 additions and 409 deletions

View File

@ -1,39 +1,20 @@
# Introduction
This recipe includes scripts for training a Zipformer model using multiple Chinese datasets.
This recipe includes scripts for training a [Qwen-Audio](https://github.com/QwenLM/Qwen-Audio/tree/main)-style model using multiple datasets.
# Included Training Sets
1. THCHS-30
2. AiShell-{1,2,4}
3. ST-CMDS
4. Primewords
5. MagicData
6. Aidatatang_200zh
7. AliMeeting
8. WeNetSpeech
9. KeSpeech-ASR
<br>
<p align="center">
<img src="assets/framework.png" width="800"/>
</p>
<br>
|Dataset| Number of hours| URL|
|---|---:|---|
|**TOTAL**|14,106|---|
|THCHS-30|35|https://www.openslr.org/18/|
|AiShell-1|170|https://www.openslr.org/33/|
|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
|AiShell-4|120|https://www.openslr.org/111/|
|ST-CMDS|110|https://www.openslr.org/38/|
|Primewords|99|https://www.openslr.org/47/|
|aidatatang_200zh|200|https://www.openslr.org/62/|
|MagicData|755|https://www.openslr.org/68/|
|AliMeeting|100|https://openslr.org/119/|
|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech|
|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech|
[./RESULTS.md](./RESULTS.md) contains the latest results.
# ASR_LLM
# Included Test Sets
1. Aishell-{1,2,4}
2. Aidatatang_200zh
3. AliMeeting
4. MagicData
5. KeSpeech-ASR
6. WeNetSpeech
The following table lists the folders for the different tasks; a minimal sketch of how the speech encoder and LLM are composed follows the table.
| | Speech Encoder | LLM | Comment |
|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
| [whisper_llm_zh](./whisper_llm_zh) | Whisper | Qwen2 | [Using multiple Chinese datasets](https://github.com/k2-fsa/icefall/tree/master/egs/multi_zh-hans/ASR) |
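For orientation, here is a minimal sketch of how the pieces in `whisper_llm_zh` fit together, assuming it is run from that folder with the recipe's requirements installed. The model names and the 8x downsample rate mirror the defaults referenced in this recipe; everything else is illustrative, not the exact training setup.

```python
# Rough sketch of the whisper_llm_zh model composition (illustrative, not the
# recipe's training entry point). Run from egs/speech_llm/ASR_LLM/whisper_llm_zh.
import torch
import whisper
from transformers import AutoModelForCausalLM, AutoTokenizer

from model import SPEECH_LLM, EncoderProjector  # defined in this recipe's model.py

# Whisper encoder as the speech encoder (a fine-tuned checkpoint path also works).
whisper_model = whisper.load_model("large-v2", "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state  # 1280 for large-v2

# Qwen2 LLM and tokenizer.
llm_name = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
llm = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch.float16)

# Linear projector that maps (and 8x downsamples) encoder frames to the LLM dimension.
encoder_projector = EncoderProjector(
    speech_encoder_dim, llm.config.hidden_size, downsample_rate=8
)

model = SPEECH_LLM(speech_encoder, llm, encoder_projector)
```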

View File

@ -1,116 +1,62 @@
## Results
### Multi Chinese datasets (without datatang 200h) fine-tuning results on Whisper-large-v2
#### Whisper
[./whisper](./whisper)
### whisper_llm_zh fine-tuning results
The Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search; a reference CER computation is sketched after the table.
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
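The recipe itself scores with icefall's `write_error_stats`; purely for orientation, CER is the character-level edit distance between hypothesis and reference, divided by the number of reference characters. A minimal, self-contained sketch with toy strings (not the recipe's scorer):

```python
# Minimal CER sketch: character-level edit distance normalized by reference length.
def edit_distance(ref: str, hyp: str) -> int:
    # Classic dynamic-programming Levenshtein distance.
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
            prev = cur
    return dp[-1]


def cer(refs, hyps) -> float:
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    chars = sum(len(r) for r in refs)
    return 100.0 * errors / chars


# One substitution over six characters -> ~16.67% CER.
print(cer(["今天天气很好"], ["今天天汽很好"]))
```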
| Training Dataset | Speech Encoder | LLM | Projector | Comment | CER |
|---|---|---|---|---|---|
| Aishell1 | whisper-large-v2-aishell1-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| [yuekai/icefall_asr_aishell_whisper_qwen2_1.5B](https://huggingface.co/yuekai/icefall_asr_aishell_whisper_qwen2_1.5B) | Aishell1 Test 3.76% |
<!-- | Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze| Qwen2-1.5B-Instruct, LoRA | Linear, 8x downsample| WIP ||
| Multi-hans-zh | whisper-large-v2-multi-hans-ft, freeze| Qwen2-7B-Instruct, LoRA | Linear, 8x downsample| WIP || -->
The command for training is:
```bash
pip install -r whisper/requirements.txt
pip install -r whisper_llm_zh/requirements.txt
# We updated the labels of WenetSpeech to remove OCR deletion errors; see https://github.com/wenet-e2e/WenetSpeech/discussions/54
pip install "huggingface_hub[cli]"
mkdir -p models/whisper models/qwen
torchrun --nproc-per-node 8 ./whisper/train.py \
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True
```
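The `ds_config_zero1.json` passed via `--deepspeed_config` ships with the recipe. Purely for orientation, a ZeRO stage-1 DeepSpeed configuration of this kind can be generated roughly as follows; all values here are illustrative assumptions, not the recipe's actual settings.

```python
# Illustrative ZeRO stage-1 DeepSpeed config (values are assumptions, not the
# recipe's ds_config_zero1.json). Written out as JSON for --deepspeed_config.
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 5.0,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}

with open("ds_config_zero1_example.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```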
The command for decoding with fine-tuned models is:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
```
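Before decoding, it can save time to confirm that the checkpoint symlink resolves and the file loads as a plain state dict. A small sanity check; the path is taken from the commands above, everything else is a generic sketch:

```python
# Sanity-check the linked checkpoint before decoding.
# The path matches the symlink created by the commands above.
import torch

ckpt_path = "whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt"
state_dict = torch.load(ckpt_path, map_location="cpu")
print(f"loaded {len(state_dict)} entries")
# The trainable parts (projector and LoRA weights) are expected to appear here,
# since frozen parameters are excluded when the checkpoint is saved.
for name in list(state_dict)[:10]:
    print(name)
```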
Fine-tuned models, training logs, decoding logs, TensorBoard logs, and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
### Multi Chinese datasets char-based training results (non-streaming) on the Zipformer model
This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
#### Non-streaming (with CTC head)
Best results (number of parameters: ~69M):
The training command:
```
./zipformer/train.py \
--world-size 4 \
--num-epochs 20 \
--use-fp16 1 \
--max-duration 600 \
--num-workers 8 \
--use-ctc 1
```
The decoding command:
```
./zipformer/decode.py \
--epoch 20 \
--avg 1 \
--use-ctc 1
```
The Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch, using a BPE model (2000 tokens, byte fallback enabled).
| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 |
| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 |
The pre-trained model can be found here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/
#### Non-streaming
Best results (number of parameters: ~69M):
The training command:
```
./zipformer/train.py \
--world-size 4 \
--num-epochs 20 \
--use-fp16 1 \
--max-duration 600 \
--num-workers 8
```
The decoding command:
```
./zipformer/decode.py \
--epoch 20 \
--avg 1
```
The Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and a BPE model (2000 tokens, byte fallback enabled).
| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |
The pre-trained model can be found here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/

Binary file not shown.


View File

@ -0,0 +1,46 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=0
stop_stage=0
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download whisper-large-v2 aishell 1 fbank feature from huggingface"
# pip install huggingface_hub['cli']
# for aishell 1
huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
# for multi-hans-zh
huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: Download whisper-large-v2 speechio test sets fbank feature from huggingface"
# for speechio test sets
mkdir -p data_speechio
huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio
mv data_speechio/fbank/* data/fbank
fi
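After downloading, it is worth spot-checking that the manifests under `data/fbank` load with lhotse. A minimal sketch; the manifest name is one of those referenced by `multi_dataset.py`:

```python
# Spot-check a downloaded Lhotse cut manifest (name taken from multi_dataset.py).
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/aishell_cuts_dev.jsonl.gz")
n, seconds = 0, 0.0
for cut in cuts:
    n += 1
    seconds += cut.duration
print(f"{n} cuts, {seconds / 3600:.1f} hours")
```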

View File

@ -20,21 +20,34 @@
"""
Usage:
# Command for decoding using fine-tuned models:
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
pip install "huggingface_hub[cli]"
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
"""
import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -42,18 +55,17 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
#from tn.chinese.normalizer import Normalizer
#from whisper.normalizers import BasicTextNormalizer
#from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
#from zhconv import convert
import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from model import EncoderProjector, SPEECH_LLM
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@ -63,8 +75,7 @@ from icefall.utils import (
str2bool,
write_error_stats,
)
from train import DEFAULT_SPEECH_TOKEN
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
@ -117,6 +128,7 @@ def average_checkpoints(
return avg
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
@ -135,7 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=1,
default=8,
help="Downsample rate for the encoder projector.",
)
@ -149,10 +161,11 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--use-lora",
type=str2bool,
default=False,
help="Whether to use lora to fine-tune llm.",
default=True,
help="Whether to use lora fine-tuned llm checkpoint.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -247,6 +260,7 @@ def decode_one_batch(
Returns:
Return a dict, whose key may be "beam-search".
"""
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
@ -268,10 +282,16 @@ def decode_one_batch(
)
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == 'right':
texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
@ -302,16 +322,18 @@ def decode_one_batch(
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
messages = [[
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": ""},
]] * len(feature)
messages = [
[
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": ""},
]
] * len(feature)
input_ids, attention_mask = preprocess(
messages, tokenizer, max_len=128
input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return {"beam-search": hyps}
@ -497,14 +519,14 @@ def main():
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16
torch_dtype=torch.float16
tokenizer.padding_side = 'left'
# torch_dtype=torch.bfloat16 FIX ME
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype=torch.float16
tokenizer.padding_side = 'right'
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
@ -515,23 +537,33 @@ def main():
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
}
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
model = SPEECH_LLM(
speech_encoder,
@ -539,7 +571,6 @@ def main():
encoder_projector,
)
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
@ -579,7 +610,7 @@ def main():
#
if c.duration > 30.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
return True
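To make the padding branches in `preprocess` above concrete, here is a toy illustration of right- versus left-padding of token-id lists. The ids are made up; the real ones come from the Qwen2 tokenizer, and `decode.py` selects left padding when flash-attention is enabled (see `tokenizer.padding_side` above).

```python
# Toy illustration of the left/right padding branches in preprocess().
# Token ids are made up; the real ones come from the Qwen2 tokenizer.
pad_token_id = 0
texts = [[11, 12, 13], [21, 22]]
max_len = max(len(t) for t in texts)

right_padded = [t + [pad_token_id] * (max_len - len(t)) for t in texts]
left_padded = [[pad_token_id] * (max_len - len(t)) + t for t in texts]

print(right_padded)  # [[11, 12, 13], [21, 22, 0]]
print(left_padded)   # [[11, 12, 13], [0, 21, 22]]
```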

View File

@ -1 +0,0 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@ -1,11 +1,20 @@
from torch import nn
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
from icefall.dist import get_rank
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module):
# https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
Args:
encoder_dim (:obj:`int`): The dimension of the encoder outputs.
llm_dim (:obj:`int`): The dimension of the language model.
downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
"""
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
@ -22,14 +31,28 @@ class EncoderProjector(nn.Module):
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate)
x = x.view(
batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class SPEECH_LLM(nn.Module):
"""
The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
The encoder is used to extract speech features from the input speech signal.
The encoder projector is used to project the encoder outputs to the same dimension as the language model.
The language model is used to generate the text from the speech features.
Args:
encoder (:obj:`nn.Module`): The encoder module.
llm (:obj:`nn.Module`): The language model module.
encoder_projector (:obj:`nn.Module`): The encoder projector module.
"""
def __init__(
self,
encoder: nn.Module,
@ -41,23 +64,46 @@ class SPEECH_LLM(nn.Module):
self.llm = llm
self.encoder_projector = encoder_projector
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
"""
Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
with the speech features and padding the input_ids to the maximum length of the speech features.
Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
Args:
speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
input_ids (:obj:`torch.Tensor`): The input ids to merge.
attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
Returns:
:obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
"""
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id))
left_padding = not torch.sum(
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
)
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_speech_tokens.max() * (speech_len - 1)) + sequence_length
batch_indices, non_speech_indices = torch.where(input_ids != self.llm.config.default_speech_token_id)
max_embed_dim = (
num_special_speech_tokens.max() * (speech_len - 1)
) + sequence_length
batch_indices, non_speech_indices = torch.where(
input_ids != self.llm.config.default_speech_token_id
)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
new_token_positions = (
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
)
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_speech_pad[:, None] # offset for left padding
@ -65,14 +111,24 @@ class SPEECH_LLM(nn.Module):
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
batch_size,
max_embed_dim,
embed_dim,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
batch_size,
max_embed_dim,
dtype=attention_mask.dtype,
device=inputs_embeds.device,
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), IGNORE_TOKEN_ID, dtype=input_ids.dtype, device=input_ids.device
(batch_size, max_embed_dim),
IGNORE_TOKEN_ID,
dtype=input_ids.dtype,
device=input_ids.device,
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
@ -86,17 +142,28 @@ class SPEECH_LLM(nn.Module):
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_speech_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_speech_indices]
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
batch_indices, non_speech_indices
]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
batch_indices, non_speech_indices
]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_speech_indices]
final_labels[batch_indices, text_to_overwrite] = labels[
batch_indices, non_speech_indices
]
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
speech_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
(batch_size, max_embed_dim),
True,
dtype=torch.bool,
device=inputs_embeds.device,
)
speech_to_overwrite[batch_indices, text_to_overwrite] = False
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[:, None].to(target_device)
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
:, None
].to(target_device)
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
raise ValueError(
@ -104,12 +171,18 @@ class SPEECH_LLM(nn.Module):
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
)
final_embedding[speech_to_overwrite] = speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_embedding[speech_to_overwrite] = (
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
)
final_attention_mask |= speech_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
(final_attention_mask == 0), 1
)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.llm.config.pad_token_id)
batch_indices, pad_indices = torch.where(
input_ids == self.llm.config.pad_token_id
)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
@ -119,62 +192,59 @@ class SPEECH_LLM(nn.Module):
return final_embedding, final_attention_mask, final_labels, position_ids
def forward(self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
):
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
enable_logging = False
rank = get_rank()
# log only on rank 0, training using deep
if enable_logging and rank == 0:
print("input_ids", input_ids, input_ids.shape)
print("labels", labels, labels.shape)
print("inputs_embeds", inputs_embeds.shape, inputs_embeds)
print("attention_mask_before", attention_mask.shape, attention_mask)
print(2333333333333333333333333333)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_speech_features(
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
if enable_logging and rank == 0:
print("speech_features", speech_features.shape, speech_features)
print("inputs_embeds after", inputs_embeds.shape, inputs_embeds)
print("attention_mask", attention_mask.shape, attention_mask)
print("position_ids", position_ids.shape, position_ids)
print("labels", labels, labels.shape)
print("================================================================")
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
# model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
if enable_logging and rank == 0:
print("preds", preds, preds.shape)
print(4555555555555555555555555555555555555555555)
acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=IGNORE_TOKEN_ID)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return model_outputs, acc
def decode(self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs
):
def decode(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
inputs_embeds, attention_mask, _, position_ids = self._merge_input_ids_with_speech_features(
(
inputs_embeds,
attention_mask,
_,
position_ids,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
@ -189,7 +259,7 @@ class SPEECH_LLM(nn.Module):
temperature=kwargs.get("temperature", 1.0),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id
pad_token_id=self.llm.config.pad_token_id,
)
return generated_ids
@ -197,7 +267,7 @@ class SPEECH_LLM(nn.Module):
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
@ -212,4 +282,4 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
return numerator.float() / denominator.float()
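As a quick check of the accuracy computation above, here is a toy example that inlines the same masked comparison as `compute_accuracy`; the values are made up.

```python
# Toy check of the masked accuracy: 2 of 3 non-ignored positions match -> ~0.667.
import torch

IGNORE_TOKEN_ID = -100  # LabelSmoother.ignore_index, as used in model.py

pad_outputs = torch.tensor([[5, 6, 7, 9]])
pad_targets = torch.tensor([[5, 6, 8, IGNORE_TOKEN_ID]])

mask = pad_targets != IGNORE_TOKEN_ID
numerator = torch.sum(pad_outputs.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
print((numerator.float() / denominator.float()).item())  # 0.666...
```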

View File

@ -248,8 +248,6 @@ class MultiDataset:
def aishell_train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
# AISHELL-1
logging.info("Loading Aishell-1 in lazy mode")
aishell_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
@ -257,11 +255,8 @@ class MultiDataset:
return aishell_cuts
def aishell_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
@ -271,8 +266,6 @@ class MultiDataset:
def aishell_test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
# AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
@ -282,12 +275,8 @@ class MultiDataset:
"aishell_test": aishell_test_cuts,
}
# aishell 2
def aishell2_train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
# AISHELL-2
logging.info("Loading Aishell-2 in lazy mode")
aishell_2_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
@ -297,8 +286,6 @@ class MultiDataset:
def aishell2_dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# AISHELL-2
logging.info("Loading Aishell-2 set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
@ -308,8 +295,6 @@ class MultiDataset:
def aishell2_test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
# AISHELL-2
logging.info("Loading Aishell-2 set in lazy mode")
aishell2_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
@ -321,8 +306,6 @@ class MultiDataset:
def wenetspeech_test_meeting_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
# WeNetSpeech
logging.info("Loading WeNetSpeech set in lazy mode")
wenetspeech_test_meeting_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"

egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt Executable file → Normal file
View File

@ -5,9 +5,6 @@ sentencepiece
pypinyin
tensorboard
librosa
# git+https://github.com/yuekaizhang/whisper.git
# zhconv
# WeTextProcessing
deepspeed
transformers>=4.37.0
flash-attn
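flash-attn only runs on sufficiently recent NVIDIA GPUs, so a quick environment check before passing `--use-flash-attn True` can help. The compute-capability threshold below (8.0, Ampere) is an assumption about typical flash-attn builds; adjust it for your installed version.

```python
# Quick environment check before enabling --use-flash-attn True.
import importlib.util
import torch

has_flash_attn = importlib.util.find_spec("flash_attn") is not None
major, minor = (
    torch.cuda.get_device_capability(0) if torch.cuda.is_available() else (0, 0)
)
print("flash_attn installed:", has_flash_attn)
print(
    "GPU compute capability:", (major, minor),
    "-> ok" if major >= 8 else "-> fall back to eager attention",
)
```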

View File

@ -17,14 +17,28 @@
# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2
pip install "huggingface_hub[cli]"
mkdir -p models/whisper models/qwen
# fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True
"""
import argparse
@ -39,36 +53,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import k2
# import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from model import SPEECH_LLM, EncoderProjector, IGNORE_TOKEN_ID
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
# from optim import Eden, ScaledAdam
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
# from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
@ -77,20 +84,15 @@ from icefall.utils import (
str2bool,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
#IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
@ -109,7 +111,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=1,
default=8,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
@ -133,6 +135,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to unfreeze llm during training.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -162,15 +165,6 @@ def get_parser():
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -198,26 +192,6 @@ def get_parser():
""",
)
parser.add_argument(
"--base-lr", type=float, default=1e-5, help="The base learning rate."
)
parser.add_argument(
"--lr-batches",
type=float,
default=5000,
help="""Number of steps that affects how rapidly the learning rate
decreases. We suggest not to change this.""",
)
parser.add_argument(
"--lr-epochs",
type=float,
default=6,
help="""Number of epochs that affects how rapidly the learning rate decreases.
""",
)
parser.add_argument(
"--seed",
type=int,
@ -225,44 +199,6 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--keep-last-k",
type=int,
default=30,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
It does not affect checkpoints with name `epoch-xxx.pt`.
""",
)
parser.add_argument(
"--average-period",
type=int,
default=200,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
parameters from the start of training. Each time we take the average,
we do: `model_avg = model * (average_period / batch_idx_train) +
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
@ -325,6 +261,7 @@ def get_params() -> AttributeDict:
return params
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
@ -372,17 +309,23 @@ def compute_loss(
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
padding="longest", # FIX me change padding to longest
max_length=max_len,
truncation=True,
)
)
# pad texts to the same length; texts is a list of lists, padded with tokenizer.pad_token_id
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == 'right':
texts = [text + [tokenizer.pad_token_id] * (max_len_texts - len(text)) for text in texts]
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text for text in texts]
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone()
@ -391,13 +334,14 @@ def compute_loss(
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
mask_indices = torch.where(input_ids == tokenizer.convert_tokens_to_ids("assistant"))
# then mask all tokens before the first token e.g. 151646 (speech), 151645 <assistant>, 198 \n
mask_indices = torch.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
target_ids[row, :col+2] = IGNORE_TOKEN_ID
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
@ -458,20 +402,13 @@ def compute_loss(
messages = []
for i, text in enumerate(texts):
# message = [
# {"role": "system", "content": "你是一个能处理音频的助手。"},
# {"role": "user", "content": f"请转写音频为文字 {DEFAULT_SPEECH_TOKEN}"},
# {"role": "assistant", "content": text},
# ]
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text},
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text},
]
messages.append(message)
input_ids, attention_mask, target_ids = preprocess(
messages, tokenizer, max_len=128
)
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
@ -494,7 +431,9 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["acc"] = acc * info["frames"] # WAR: to avoid normalization by the number of frames
info["acc"] = (
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
return loss, info
@ -607,7 +546,7 @@ def train_one_epoch(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={},
exclude_frozen_parameters=True
exclude_frozen_parameters=True,
)
if rank == 0:
@ -703,10 +642,7 @@ def run(rank, world_size, args):
logging.info("About to create model")
# if 'whisper' in params.speech_encoder_path_or_name:
replace_whisper_encoder_forward()
# TODO: directly loading from whisper-ft checkpoint
# whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
@ -717,14 +653,14 @@ def run(rank, world_size, args):
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16
torch_dtype=torch.float16
tokenizer.padding_side = 'left'
# torch_dtype=torch.bfloat16 FIX ME
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype=torch.float16
tokenizer.padding_side = 'right'
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
@ -741,21 +677,31 @@ def run(rank, world_size, args):
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
}
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
model = SPEECH_LLM(
speech_encoder,
@ -814,12 +760,6 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
# if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# # We only load the sampler's state dict when it loads a checkpoint
# # saved in the middle of an epoch
# sampler_state_dict = checkpoints["sampler"]
# else:
# sampler_state_dict = None
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
@ -840,13 +780,6 @@ def run(rank, world_size, args):
else:
tb_writer = None
# if params.pretrained_model_path:
# checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
# if "model" not in checkpoint:
# model.load_state_dict(checkpoint, strict=True)
# else:
# load_checkpoint(params.pretrained_model_path, model)
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
@ -871,12 +804,11 @@ def run(rank, world_size, args):
rank=rank,
)
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
@ -887,13 +819,16 @@ def run(rank, world_size, args):
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(sampler_state_dict, f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt")
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
logging.info("Done!")
def display_and_save_batch(
batch: dict,
params: AttributeDict,