migrate from speech llm

This commit is contained in:
root 2025-02-26 04:53:23 +00:00
parent 64c5364085
commit cca562d538
9 changed files with 2655 additions and 0 deletions


@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
# Copyright 2023 Xiaomi Corp. (Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
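#
# Example invocation (a sketch; the arguments shown are this script's defaults,
# see get_parser() below):
#
#   python3 local/compute_whisper_fbank.py \
#     --out-dir data/fbank \
#     --huggingface-dataset-path-or-name /workspace/Belle_1.4M-SLAM-Omni \
#     --audio-key question_audio \
#     --text-key answer \
#     --whisper-fbank True \
#     --resample-to-16kHz True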
import argparse
import logging
from pathlib import Path
import torch
from datasets import load_dataset
from lhotse import (
    CutSet,
    KaldifeatFbank,
    KaldifeatFbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=True,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
parser.add_argument(
"--resample-to-16kHz",
type=str2bool,
default=True,
help="Resample audio to 16kHz. Default: False.",
)
parser.add_argument(
"--speed-perturb",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
parser.add_argument(
"--out-dir",
type=str,
default="data/fbank",
help="Output directory for the computed features",
)
parser.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
default="/workspace/Belle_1.4M-SLAM-Omni",
help="The path or name of the Huggingface dataset",
)
parser.add_argument(
"--audio-key",
type=str,
default="question_audio",
help="The key in the Huggingface dataset containing the audio data",
)
parser.add_argument(
"--text-key",
type=str,
default="answer",
help="The key in the Huggingface dataset containing the text data",
)
return parser
def compute_fbank(args):
in_out_dir = Path(args.out_dir)
in_out_dir.mkdir(parents=True, exist_ok=True)
# number of workers in dataloader
num_workers = 4
# number of seconds in a batch
batch_duration = 10
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
)
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
start = 0
stop = 1601
num_digits = 5
for i in range(start, stop):
idx = f"{i}".zfill(num_digits)
# dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
parquet_files = [
f"data/train-{idx}-of-01601.parquet",
]
parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
file_name = parquet_files[0]
logging.info(f"Loading dataset from {file_name}")
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
cut_set = CutSet.from_huggingface_dataset(dataset, audio_key=args.audio_key, text_key=args.text_key)
logging.info("Splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
if args.resample_to_16kHz:
cut_set = cut_set.resample(16000)
if args.speed_perturb:
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{idx}",
num_workers=num_workers,
batch_duration=batch_duration,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cuts_path = f"{in_out_dir}/cuts_belle.{idx}.jsonl.gz"
logging.info(f"Saving to {cuts_path}")
# cut_set.to_file(cuts_path)
remove_recording_item(cut_set, cuts_path)
def remove_recording_item(
cuts,
output_cuts,
):
"""
don't store recording item
"""
with CutSet.open_writer(output_cuts) as writer:
for cut in cuts:
cut.recording.sources = None
writer.write(cut)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank(args)
if __name__ == "__main__":
main()


@@ -0,0 +1,45 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
set -eou pipefail
stage=2
stop_stage=2
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: "
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
python3 local/compute_whisper_fbank.py
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: "
python3 ./slam_omni/decode.py \
--max-duration 80 \
--exp-dir slam_omni/exp_test_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True
fi


@@ -0,0 +1,437 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from datasets import load_dataset
from lhotse import (
    CutSet,
    WhisperFbank,
    WhisperFbankConfig,
    load_manifest,
    load_manifest_lazy,
)
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
from speech_dataset import K2SpeechRecognitionDataset
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=300.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
# group.add_argument(
# "--concatenate-cuts",
# type=str2bool,
# default=False,
# help="When enabled, utterances (cuts) will be concatenated "
# "to minimize the amount of padding.",
# )
# group.add_argument(
# "--duration-factor",
# type=float,
# default=1.0,
# help="Determines the maximum duration of a concatenated cut "
# "relative to the duration of the longest cut in a batch.",
# )
# group.add_argument(
# "--gap",
# type=float,
# default=1.0,
# help="The amount of padding (in seconds) inserted between "
# "concatenated cuts. This padding is filled with noise when "
# "noise augmentation is used.",
# )
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--huggingface-dataset-path-or-name",
type=str,
default="/workspace/Belle_1.4M-SLAM-Omni",
help="The path or name of the Huggingface dataset",
)
group.add_argument(
"--audio-key",
type=str,
default="question_audio",
help="The key in the Huggingface dataset containing the audio data",
)
group.add_argument(
"--text-key",
type=str,
default="answer",
help="The key in the Huggingface dataset containing the text data",
)
group.add_argument(
"--resample-to-16kHz",
type=str2bool,
default=True,
help="Resample audio to 16kHz. Default: False.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
# if self.args.concatenate_cuts:
# logging.info(
# f"Using cut concatenation with duration factor "
# f"{self.args.duration_factor} and gap {self.args.gap}."
# )
# # Cut concatenation should be the first transform in the list,
# # so that if we e.g. mix noise in, it will fill the gaps between
# # different utterances.
# transforms = [
# CutConcatenate(
# duration_factor=self.args.duration_factor, gap=self.args.gap
# )
# ] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
# if self.args.concatenate_cuts:
# transforms = [
# CutConcatenate(
# duration_factor=self.args.duration_factor, gap=self.args.gap
# )
# ] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda'))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(WhisperFbank(WhisperFbankConfig(num_filters=80, device='cuda')))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
# dataset = load_dataset(args.huggingface_dataset_path_or_name, streaming=True, split=partition)
i, num_digits = 0, 5
idx = f"{i}".zfill(num_digits)
parquet_files = [
f"data/train-{idx}-of-01601.parquet",
]
parquet_files = [f"{args.huggingface_dataset_path_or_name}/{f}" for f in parquet_files]
file_name = parquet_files[0]
logging.info(f"Loading dataset from {file_name}")
dataset = load_dataset('parquet', data_files=parquet_files, streaming=True, split='train')
            cut_set = CutSet.from_huggingface_dataset(
                dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
            )
            if self.args.resample_to_16kHz:
cut_set = cut_set.resample(16000)
return cut_set
else:
return load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")


@@ -0,0 +1,653 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# Command for decoding using fine-tuned models:
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import transformers
import whisper
from data_module import AsrDataModule
from lhotse.cut import Cut
from model import SPEECH_LLM, EncoderProjector
# from data_module import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
"""Average a list of checkpoints.
The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict.
Args:
filenames:
Filenames of the checkpoints to be averaged. We assume all
checkpoints are saved by :func:`save_checkpoint`.
device:
Move checkpoints to this device before averaging.
Returns:
Return a dict (i.e., state_dict) which is the average of all
model state dicts contained in the checkpoints.
"""
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
for k in uniqued_names:
avg[k] += state_dict[k]
for k in uniqued_names:
if avg[k].is_floating_point():
avg[k] /= n
else:
avg[k] //= n
return avg
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
type=str,
default="/workspace/asr/Qwen1.5-0.5B-Chat",
help="Path or name of the large language model.",
)
parser.add_argument(
"--speech-encoder-path-or-name",
type=str,
default="whisper-large-v2",
help="Path or name of the speech encoder.",
)
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=8,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser.add_argument(
"--use-lora",
type=str2bool,
default=True,
help="Whether to use lora fine-tuned llm checkpoint.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=-1,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="beam-search",
help="""Decoding method.
Supported values are:
- beam-search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=1,
help="beam size for beam search decoding",
)
parser.add_argument(
"--exp-dir",
type=str,
default="whisper/exp",
help="The experiment dir",
)
parser.add_argument(
"--remove-whisper-encoder-input-length-restriction",
type=str2bool,
default=True,
help="replace whisper encoder forward method to remove input length restriction",
)
# parser.add_argument(
# "--dataset",
# type=str,
# default="aishell",
# choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
# help="The dataset to decode",
# )
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: "beam-search"
- value: A list of lists. Each sublist is a list of token IDs.
Args:
params:
It is returned by :func:`get_params`.
model:
The neural model.
batch:
It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
Returns:
Return a dict, whose key may be "beam-search".
"""
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
max_len: int = 128,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
add_generation_prompt=False,
chat_template=TEMPLATE,
padding="longest",
max_length=max_len,
truncation=True,
)
)
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask
dtype = torch.float32
device = model.llm.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
if not params.remove_whisper_encoder_input_length_restriction:
T = 3000
if feature.shape[2] < T:
feature = torch.cat(
[
feature,
torch.zeros(
feature.shape[0], feature.shape[1], T - feature.shape[2]
).to(device, dtype=dtype),
],
2,
)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
messages = [
[
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": ""},
]
] * len(feature)
input_ids, attention_mask = preprocess(messages, tokenizer, max_len=128)
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
)
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(hyps)
print(supervisions)
return {"beam-search": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
tokenizer: AutoTokenizer,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
The dataloader.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a dict, whose key may be "beam-search".
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
print(f"ref: {ref_text}")
print(f"hyp: {''.join(hyp_words)}")
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
        # We compute CER (character error rate) for Chinese text.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
)
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
logging.info(f"device: {device}")
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16 FIX ME
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
)
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
data_module = AsrDataModule(args)
# data_module = MultiDataset(args.manifest_dir)
def remove_long_utt(c: Cut):
        # Keep only utterances whose duration is at most 30 seconds.
if c.duration > 30.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
return True
# if params.dataset == "aishell":
# test_sets_cuts = data_module.aishell_test_cuts()
# elif params.dataset == "speechio":
# test_sets_cuts = data_module.speechio_test_cuts()
# elif params.dataset == "wenetspeech_test_meeting":
# test_sets_cuts = data_module.wenetspeech_test_meeting_cuts()
# else:
    # test_cuts() returns a single CutSet; wrap it in a dict keyed by a test-set
    # label (used only in output file names) so the loop below can iterate over
    # (name, cuts) pairs.
    test_sets_cuts = {"belle_test": data_module.test_cuts()}
    test_sets = test_sets_cuts.keys()
test_dls = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
for cuts_name in test_sets
]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
tokenizer=tokenizer,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()


@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py


@@ -0,0 +1,285 @@
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
class EncoderProjector(nn.Module):
"""
The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model.
Modified from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/models/projector.py.
Args:
encoder_dim (:obj:`int`): The dimension of the encoder outputs.
llm_dim (:obj:`int`): The dimension of the language model.
downsample_rate (:obj:`int`, `optional`, defaults to 5): The downsample rate to use.
"""
def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
super().__init__()
self.downsample_rate = downsample_rate
self.linear1 = nn.Linear(encoder_dim * self.downsample_rate, llm_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(llm_dim, llm_dim)
def forward(self, x):
batch_size, seq_len, feat_dim = x.size()
num_frames_to_discard = seq_len % self.downsample_rate
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(
batch_size, seq_len // self.downsample_rate, feat_dim * self.downsample_rate
)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class SPEECH_LLM(nn.Module):
"""
The Speech-to-Text model. It consists of an encoder, a language model and an encoder projector.
The encoder is used to extract speech features from the input speech signal.
The encoder projector is used to project the encoder outputs to the same dimension as the language model.
The language model is used to generate the text from the speech features.
Args:
encoder (:obj:`nn.Module`): The encoder module.
llm (:obj:`nn.Module`): The language model module.
encoder_projector (:obj:`nn.Module`): The encoder projector module.
"""
def __init__(
self,
encoder: nn.Module,
llm: nn.Module,
encoder_projector: nn.Module,
):
super().__init__()
self.encoder = encoder
self.llm = llm
self.encoder_projector = encoder_projector
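    # Construction sketch (hedged; decode.py / train.py do the actual wiring,
    # the model names and dimensions here are only examples):
    #   speech_encoder = whisper.load_model("large-v2", "cpu").encoder
    #   llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
    #   projector = EncoderProjector(1280, llm.config.hidden_size, downsample_rate=8)
    #   model = SPEECH_LLM(speech_encoder, llm, projector)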
def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
):
"""
Merge the speech features with the input_ids and attention_mask. This is done by replacing the speech tokens
with the speech features and padding the input_ids to the maximum length of the speech features.
Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L277.
Args:
speech_features (:obj:`torch.Tensor`): The speech features to merge with the input_ids.
inputs_embeds (:obj:`torch.Tensor`): The embeddings of the input_ids.
input_ids (:obj:`torch.Tensor`): The input ids to merge.
attention_mask (:obj:`torch.Tensor`): The attention mask to merge.
labels (:obj:`torch.Tensor`, `optional`): The labels to merge.
Returns:
:obj:`Tuple(torch.Tensor)`: The merged embeddings, attention mask, labels and position ids.
"""
num_speechs, speech_len, embed_dim = speech_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
)
# 1. Create a mask to know where special speech tokens are
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (
num_special_speech_tokens.max() * (speech_len - 1)
) + sequence_length
batch_indices, non_speech_indices = torch.where(
input_ids != self.llm.config.default_speech_token_id
)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged speech-text sequence.
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = (
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
)
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_speech_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size,
max_embed_dim,
embed_dim,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
final_attention_mask = torch.zeros(
batch_size,
max_embed_dim,
dtype=attention_mask.dtype,
device=inputs_embeds.device,
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim),
IGNORE_TOKEN_ID,
dtype=input_ids.dtype,
device=input_ids.device,
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_speech_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_speech_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
batch_indices, non_speech_indices
]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
batch_indices, non_speech_indices
]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[
batch_indices, non_speech_indices
]
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
speech_to_overwrite = torch.full(
(batch_size, max_embed_dim),
True,
dtype=torch.bool,
device=inputs_embeds.device,
)
speech_to_overwrite[batch_indices, text_to_overwrite] = False
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
:, None
].to(target_device)
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
)
final_embedding[speech_to_overwrite] = (
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
)
final_attention_mask |= speech_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
(final_attention_mask == 0), 1
)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(
input_ids == self.llm.config.pad_token_id
)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
def forward(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
labels: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
model_outputs = self.llm(
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
)
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc = compute_accuracy(
preds.detach()[:, :-1],
labels.detach()[:, 1:],
ignore_label=IGNORE_TOKEN_ID,
)
return model_outputs, acc
def decode(
self,
fbank: torch.Tensor = None,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor = None,
**kwargs,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
speech_features = speech_features.to(torch.float16)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
(
inputs_embeds,
attention_mask,
_,
position_ids,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask
)
generated_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 200),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
bos_token_id=self.llm.config.bos_token_id,
eos_token_id=self.llm.config.eos_token_id,
pad_token_id=self.llm.config.pad_token_id,
)
return generated_ids
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py
Args:
pad_outputs (LongTensor): Prediction tensors (B, Lmax).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
)
denominator = torch.sum(mask)
return numerator.float() / denominator.float()
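# Quick sanity check of compute_accuracy (illustrative values): two of the three
# unmasked positions match, so the result is about 0.6667.
#   preds   = torch.tensor([[1, 2, 3]])
#   targets = torch.tensor([[1, 2, 4]])
#   compute_accuracy(preds, targets, ignore_label=IGNORE_TOKEN_ID)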


@@ -0,0 +1,176 @@
from typing import Callable, Dict, List, Union
import torch
from torch.utils.data.dataloader import DataLoader, default_collate
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix
class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
"""
The PyTorch Dataset for the speech recognition task using k2 library.
This dataset expects to be queried with lists of cut IDs,
for which it loads features and automatically collates/batches them.
To use it with a PyTorch DataLoader, set ``batch_size=None``
and provide a :class:`SimpleCutSampler` sampler.
Each item in this dataset is a dict of:
.. code-block::
{
'inputs': float tensor with shape determined by :attr:`input_strategy`:
- single-channel:
- features: (B, T, F)
- audio: (B, T)
- multi-channel: currently not supported
'supervisions': [
{
'sequence_idx': Tensor[int] of shape (S,)
'text': List[str] of len S
# For feature input strategies
'start_frame': Tensor[int] of shape (S,)
'num_frames': Tensor[int] of shape (S,)
# For audio input strategies
'start_sample': Tensor[int] of shape (S,)
'num_samples': Tensor[int] of shape (S,)
# Optionally, when return_cuts=True
'cut': List[AnyCut] of len S
}
]
}
Dimension symbols legend:
* ``B`` - batch size (number of Cuts)
* ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
* ``T`` - number of frames of the longest Cut
* ``F`` - number of features
The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
"""
def __init__(
self,
return_cuts: bool = False,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
):
"""
k2 ASR IterableDataset constructor.
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
objects used to create that batch.
:param cut_transforms: A list of transforms to be applied on each sampled batch,
before converting cuts to an input representation (audio/features).
Examples: cut concatenation, noise cuts mixing, etc.
:param input_transforms: A list of transforms to be applied on each sampled batch,
after the cuts are converted to audio/features.
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy
# This attribute is a workaround to constantly growing HDF5 memory
# throughout the epoch. It regularly closes open file handles to
# reset the internal HDF5 caches.
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_duration and max_cuts.
"""
validate_for_asr(cuts)
self.hdf5_fix.update()
# Sort the cuts by duration so that the first one determines the batch time dimensions.
cuts = cuts.sort_by_duration(ascending=False)
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
# the supervision boundaries.
for tnfm in self.cut_transforms:
cuts = tnfm(cuts)
# Sort the cuts again after transforms
cuts = cuts.sort_by_duration(ascending=False)
# Get a tensor with batched feature matrices, shape (B, T, F)
# Collation performs auto-padding, if necessary.
input_tpl = self.input_strategy(cuts)
if len(input_tpl) == 3:
# An input strategy with fault tolerant audio reading mode.
# "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
inputs, _, cuts = input_tpl
else:
inputs, _ = input_tpl
# Get a dict of tensors that encode the positional information about supervisions
# in the batch of feature matrices. The tensors are named "sequence_idx",
# "start_frame/sample" and "num_frames/samples".
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
# Apply all available transforms on the inputs, i.e. either audio or features.
# This could be feature extraction, global MVN, SpecAugment, etc.
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
{
"text": supervision.text,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
batch["supervisions"]["cut"] = [
cut for cut in cuts for sup in cut.supervisions
]
return batch
def validate_for_asr(cuts: CutSet) -> None:
validate(cuts)
    tol = 2e-3  # allow 2 ms of jitter at supervision boundaries
for cut in cuts:
for supervision in cut.supervisions:
assert supervision.start >= -tol, (
f"Supervisions starting before the cut are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
#
# 'supervision.end' is end of supervision inside the Cut
assert supervision.end <= cut.duration + tol, (
f"Supervisions ending after the cut "
f"are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)


@@ -0,0 +1,872 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True
"""
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from data_module import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
filter_uneven_sized_batch,
setup_logger,
str2bool,
)
DEFAULT_SPEECH_TOKEN = "<speech>"
def set_batch_count(model: nn.Module, batch_count: float) -> None:
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--llm-path-or-name",
type=str,
default="/workspace/asr/Qwen1.5-0.5B-Chat",
help="Path or name of the large language model.",
)
parser.add_argument(
"--speech-encoder-path-or-name",
type=str,
default="whisper-large-v2",
help="Path or name of the speech encoder.",
)
parser.add_argument(
"--encoder-projector-ds-rate",
type=int,
default=8,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser.add_argument(
"--use-lora",
type=str2bool,
default=False,
help="Whether to use lora to fine-tune llm.",
)
parser.add_argument(
"--unfreeze-llm",
type=str2bool,
default=False,
help="Whether to unfreeze llm during training.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="whisper_qwen/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--pretrained-model-path",
type=str,
default=None,
help="""The path to the pretrained model if it is not None. Training will
start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
""",
)
parser.add_argument(
"--sampler-state-dict-path",
type=str,
default=None,
help="""The path to the sampler state dict if it is not None. Training will start from this sampler state dict.
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=True,
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-aishell",
type=str2bool,
default=True,
help="Whether to only use aishell1 dataset for training.",
)
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- frame_shift_ms: The frame shift in milliseconds.
- allowed_excess_duration_ratio: The allowed excess duration ratio.
- best_train_loss: The best training loss so far.
- best_valid_loss: The best validation loss so far.
- best_train_epoch: The epoch where the best training loss is achieved.
- best_valid_epoch: The epoch where the best validation loss is achieved.
- batch_idx_train: The batch index of the current batch.
- log_interval: Log training stats every `log_interval` batches.
- reset_interval: Reset the stats every `reset_interval` batches.
- valid_interval: Run validation every `valid_interval` batches.
- env_info: The environment information.
"""
params = AttributeDict(
{
"allowed_excess_duration_ratio": 0.1,
"subsampling_factor": 2,
"frame_shift_ms": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 5000,
"env_info": get_env_info(),
}
)
return params
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute the loss for the given batch.
Args:
params:
It is returned by :func:`get_params`.
tokenizer:
The tokenizer used to encode the text.
model:
The model for training.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
Whether it is training.
Returns:
Return a tuple of two elements. The first element is the loss tensor;
the second is a MetricsTracker containing statistics for logging.
"""
# For an uneven-sized batch, the total duration after padding could cause OOM.
# Hence, for each batch, which is sorted in descending order by length, we simply
# drop the last few shortest samples, so that the retained total frames
# (after padding) do not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
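# For example (assuming --max-duration 200, supplied by AsrDataModule, and the
# default frame_shift_ms=10):
#   max_frames         = 200 * 1000 // 10          = 20000
#   allowed_max_frames = int(20000 * (1.0 + 0.1))  = 22000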
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
max_len: int,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
max_length=max_len,
truncation=True,
)
)
# pad texts to the same length; texts is a list of lists, padded with tokenizer.pad_token_id
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask everything up to and including the "assistant" role token and the
# following newline with IGNORE_TOKEN_ID, so the loss is computed only on
# the assistant's response
mask_prompt = True
if mask_prompt:
mask_indices = torch.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 2 to skip: 'assistant', '\n'
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
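# A rough sketch of what preprocess() yields for one (prompt, transcript) pair,
# given the chat TEMPLATE above (token ids elided):
#   rendered text : <|im_start|>user\n<speech>...<|im_end|>\n
#                   <|im_start|>assistant\n{transcript}<|im_end|>
#   input_ids     : tokenized, padded version of the rendered text
#   target_ids    : input_ids with padding and everything up to and including
#                   "assistant\n" replaced by IGNORE_TOKEN_ID
#   attention_mask: True at non-padding positions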
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
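# Illustrative example (hypothetical input):
#   normalize_text_alimeeting("今天 天气 <sil> 不错 ok，")
#   -> "今天天气不错OK"  (spaces and markers removed, Latin letters upper-cased,
#      full-width punctuation stripped)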
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [normalize_text_alimeeting(text) for text in texts]
messages = []
for i, text in enumerate(texts):
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text},
]
messages.append(message)
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer, max_len=128)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
with torch.set_grad_enabled(is_training):
model_outputs, acc = model(
fbank=feature,
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
)
loss = model_outputs.loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
feature_lens = supervisions["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["acc"] = (
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
return loss, info
def compute_validation_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
tokenizer:
The tokenizer used to encode the text.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.encoder_projector.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
tokenizer=tokenizer,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# DeepSpeed replaces the usual loss.backward()/optimizer.step() pair:
# model.backward(loss) handles loss scaling and gradient accumulation
# internally, and model.step() updates the parameters.
model.backward(loss)
model.step()
except: # noqa
display_and_save_batch(batch, params=params)
raise
if batch_idx % params.log_interval == 0:
try:
cur_lr = scheduler.get_last_lr()[0]
except: # noqa
cur_lr = 0.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# FIXME: consider using torch.bfloat16 here
torch_dtype = torch.float16
tokenizer.padding_side = "left"
else:
attn_implementation = "eager"
torch_dtype = torch.float16
tokenizer.padding_side = "right"
llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
llm.eval()
else:
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
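# Rough shape sketch (assuming whisper-large-v2): the encoder emits ~1500 frames
# of width n_audio_state=1280 per 30 s (padded) segment; with
# encoder_projector_ds_rate=8 the projector reduces this to ~187 vectors of size
# llm.config.hidden_size before they are fed to the LLM.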
model = SPEECH_LLM(
speech_encoder,
llm,
encoder_projector,
)
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
logging.info("Trainable parameters (excluding model.eval modules):")
for name, param in model.named_parameters():
if param.requires_grad:
logging.info(f"{name}: {param.shape}")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
)
data_module = AsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
if params.use_aishell:
train_cuts = multi_dataset.aishell_train_cuts()
else:
train_cuts = multi_dataset.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration
# TODO: load sampler state dict
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
if params.use_aishell:
valid_cuts = multi_dataset.aishell_dev_cuts()
else:
valid_cuts = multi_dataset.dev_cuts()
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
tokenizer=tokenizer,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
sampler_state_dict = train_dl.sampler.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
logging.info("Done!")
def display_and_save_batch(
batch: dict,
params: AttributeDict,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = get_world_size()
rank = get_rank()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py