Merge branch 'k2-fsa:master' into dev/k2ssl

Yifan Yang 2025-05-11 00:33:35 +08:00 committed by GitHub
commit 260d37b65a
7 changed files with 74 additions and 166 deletions


@@ -42,8 +42,8 @@ huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishel
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
# First, we only train the projector and freeze other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
@@ -55,9 +55,10 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
---use-lora False --unfreeze-llm False
+--use-lora False \
+--unfreeze-llm False
-# Then we jointly train the projector and LLM LoRA modules.
+# Then, we jointly train the projector and LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
@@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
---use-lora True --unfreeze-llm True
+--use-lora True \
+--unfreeze-llm True \
--pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
```
@@ -77,11 +79,11 @@ mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
@@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
---use-lora True --dataset aishell
+--use-lora True \
+--dataset aishell
```
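
Before running the decode command above, it can be worth sanity-checking the symlinked checkpoint. A minimal sketch (not part of the recipe), assuming epoch-999.pt is a plain model state_dict, which is what decode.py loads with --epoch 999 --avg 1:

```
import torch

# Assumption: the downloaded file is a flat state_dict (name -> tensor),
# as produced by the DeepSpeed-to-fp32 conversion used in this recipe.
sd = torch.load(
    "whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt",
    map_location="cpu",
)
print(f"{len(sd)} tensors, {sum(v.numel() for v in sd.values()) / 1e6:.1f}M params")
```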

egs/speech_llm/ASR_LLM/prepare.sh (13 changes) Normal file → Executable file

@@ -7,6 +7,9 @@ set -eou pipefail
stage=0
stop_stage=0
. shared/parse_options.sh || exit 1
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
@@ -23,7 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# pip install huggingface_hub['cli']
# for aishell 1
-huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse
+huggingface-cli download --repo-type dataset --local-dir data yuekai/aishell_whisper_fbank_lhotse
fi
@@ -31,9 +34,9 @@
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
# for multi-hans-zh
-huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
-huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
-huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
+huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
+huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
+huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -41,6 +44,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# for speechio test sets
mkdir data_speechio
-huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio
+huggingface-cli download --repo-type model --local-dir data_speechio yuekai/icefall_asr_speechio
mv data_speechio/fbank/* data/fbank
fi
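
The `--repo-type dataset` flags added above matter because huggingface-cli download defaults to `--repo-type model`, so the fbank dataset repos would otherwise be looked up as model repos and fail. A rough Python equivalent via huggingface_hub (a sketch, not part of prepare.sh):

```
from huggingface_hub import snapshot_download

# Equivalent of: huggingface-cli download --repo-type dataset --local-dir data/fbank ...
snapshot_download(
    repo_id="yuekai/wenetspeech_whisper_fbank_lhotse",
    repo_type="dataset",  # the default is "model"; dataset repos need this
    local_dir="data/fbank",
)
```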


@@ -0,0 +1 @@
+../../../icefall/shared
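
Git stores a symbolic link as a one-line file containing its target path, so this new entry is evidently the `shared` symlink into `icefall/shared`, which provides the `shared/parse_options.sh` sourced by prepare.sh above.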


@@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
@@ -357,43 +357,6 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
-def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-"""
-Text normalization similar to M2MeT challenge baseline.
-See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-"""
-if normalize == "none":
-return text
-elif normalize == "m2met":
-import re
-text = text.replace(" ", "")
-text = text.replace("<sil>", "")
-text = text.replace("<%>", "")
-text = text.replace("<->", "")
-text = text.replace("<$>", "")
-text = text.replace("<#>", "")
-text = text.replace("<_>", "")
-text = text.replace("<space>", "")
-text = text.replace("`", "")
-text = text.replace("&", "")
-text = text.replace(",", "")
-if re.search("[a-zA-Z]", text):
-text = text.upper()
-text = text.replace("Ａ", "A")
-text = text.replace("ａ", "A")
-text = text.replace("ｂ", "B")
-text = text.replace("ｃ", "C")
-text = text.replace("ｋ", "K")
-text = text.replace("ｔ", "T")
-text = text.replace("，", "")
-text = text.replace("丶", "")
-text = text.replace("。", "")
-text = text.replace("、", "")
-text = text.replace("？", "")
-return text
results = []
num_cuts = 0
@@ -406,6 +369,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
+texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
@@ -418,12 +382,8 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
-for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-ref_text = normalize_text_alimeeting(ref_text)
-ref_words = ref_text.split()
-print(f"ref: {ref_text}")
-print(f"hyp: {''.join(hyp_words)}")
-this_batch.append((cut_id, ref_words, hyp_words))
+for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
+this_batch.append((cut_id, ref_text, hyp_text))
results[lm_scale].extend(this_batch)
@@ -439,40 +399,38 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
-results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
-enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
-store_transcripts(filename=recog_path, texts=results)
-if enable_log:
-logging.info(f"The transcripts are stored in {recog_path}")
+store_transcripts(filename=recog_path, texts=results, char_level=True)
+logging.info(f"The transcripts are stored in {recog_path}")
-# The following prints out WERs, per-word error statistics and aligned
+# The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
-# we compute CER for aishell dataset.
-results_char = []
-for res in results:
-results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer
-if enable_log:
-logging.info("Wrote detailed error stats to {}".format(errs_filename))
+logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+errs_info = (
+params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
@@ -495,9 +453,13 @@ def main():
params = get_params()
params.update(vars(args))
+params.res_dir = params.exp_dir / f"{params.method}"
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
params.res_dir
/ f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
)
logging.info("Decoding started")
@@ -574,23 +536,20 @@ def main():
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
-checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-)
-assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename)
# filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
# torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
map_location="cpu",
)
model.load_state_dict(checkpoint, strict=False)
@@ -643,8 +602,7 @@ def main():
logging.info("Done!")
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
if __name__ == "__main__":
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
main()
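
The rewritten decode loop splits reference texts into characters up front and stores plain (cut_id, ref, hyp) tuples, leaving CER computation to write_error_stats via compute_CER=True. A minimal sketch of that char-level pairing (illustrative values only, not the icefall API):

```
# References arrive with spaces; the new code strips them and splits to
# characters, matching the character-list hypotheses from the decoder.
ref_text = "今天 天气 不错"
hyp_chars = ["今", "天", "天", "气", "不", "错"]
ref_chars = list("".join(ref_text.split()))  # ['今', '天', '天', '气', '不', '错']
this_batch = [("cut-001", ref_chars, hyp_chars)]  # (cut_id, ref, hyp)
```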


@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
+# 2025 Yifan Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
"""
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple
import deepspeed
-import k2
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
-from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall import diagnostics
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
-from icefall.utils import (
-AttributeDict,
-MetricsTracker,
-filter_uneven_sized_batch,
-setup_logger,
-str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
DEFAULT_SPEECH_TOKEN = "<speech>"
@@ -286,13 +272,6 @@ def compute_loss(
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
-# For the uneven-sized batch, the total duration after padding would possibly
-# cause OOM. Hence, for each batch, which is sorted descendingly by length,
-# we simply drop the last few shortest samples, so that the retained total frames
-# (after padding) would not exceed `allowed_max_frames`:
-# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
-# where `max_frames = max_duration * 1000 // frame_shift_ms`.
-# We set allowed_excess_duration_ratio=0.1.
def preprocess(
messages,
@@ -347,46 +326,6 @@ def compute_loss(
return input_ids, attention_mask, target_ids
-def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-"""
-Text normalization similar to M2MeT challenge baseline.
-See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-"""
-if normalize == "none":
-return text
-elif normalize == "m2met":
-import re
-text = text.replace(" ", "")
-text = text.replace("<sil>", "")
-text = text.replace("<%>", "")
-text = text.replace("<->", "")
-text = text.replace("<$>", "")
-text = text.replace("<#>", "")
-text = text.replace("<_>", "")
-text = text.replace("<space>", "")
-text = text.replace("`", "")
-text = text.replace("&", "")
-text = text.replace(",", "")
-if re.search("[a-zA-Z]", text):
-text = text.upper()
-text = text.replace("Ａ", "A")
-text = text.replace("ａ", "A")
-text = text.replace("ｂ", "B")
-text = text.replace("ｃ", "C")
-text = text.replace("ｋ", "K")
-text = text.replace("ｔ", "T")
-text = text.replace("，", "")
-text = text.replace("丶", "")
-text = text.replace("。", "")
-text = text.replace("、", "")
-text = text.replace("？", "")
-return text
-max_frames = params.max_duration * 1000 // params.frame_shift_ms
-allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = next(model.parameters()).device
feature = batch["inputs"]
@@ -397,11 +336,10 @@ def compute_loss(
batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"]
-# remove spaces in texts
-texts = [normalize_text_alimeeting(text) for text in texts]
messages = []
for i, text in enumerate(texts):
+text = text.replace(" ", "")
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text},
@@ -516,14 +454,17 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
-model.encoder_projector.train()
+model.train()
+model.encoder.eval()
+if not params.unfreeze_llm:
+model.llm.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@@ -533,6 +474,9 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
+model.encoder.eval()
+if not params.unfreeze_llm:
+model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -648,7 +592,6 @@ def run(rank, world_size, args):
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
-speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
@@ -671,7 +614,6 @@ def run(rank, world_size, args):
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
-llm.eval()
else:
if params.use_lora:
lora_config = LoraConfig(
@@ -728,7 +670,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
model.to(device)
-assert params.deepspeed and world_size > 1
+assert params.deepspeed
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
@@ -764,7 +706,7 @@ def run(rank, world_size, args):
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
-sampler_state_dict["max_duration"] = params.max_duration
+# TODO: load sampler state dict
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
@@ -806,15 +748,15 @@ def run(rank, world_size, args):
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
f"{params.exp_dir}/epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
@@ -824,7 +766,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!")
@@ -865,6 +807,7 @@ def main():
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
+warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)
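
The training loop above now calls model.train() and then forces the frozen submodules back to eval mode, both at epoch start and right after each validation pass. A minimal sketch of the pattern, assuming only the encoder/llm attribute names visible in the diff:

```
import torch.nn as nn

def set_train_mode(model: nn.Module, unfreeze_llm: bool) -> None:
    # model.train() flips every submodule to train mode, so modules whose
    # parameters are frozen must be switched back to eval mode to keep
    # dropout and normalization layers deterministic.
    model.train()
    model.encoder.eval()
    if not unfreeze_llm:
        model.llm.eval()
```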


@@ -422,7 +422,7 @@ def compute_loss(
texts = batch["supervisions"]["text"]
unk_id = params.unk_id
-y = convert_texts_into_ids(texts, unk_id, sp=sp)
+y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):


@@ -397,7 +397,7 @@ def compute_loss(
texts = batch["supervisions"]["text"]
unk_id = params.unk_id
-y = convert_texts_into_ids(texts, unk_id, sp=sp)
+y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
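
Dropping unk_id from the call suggests convert_texts_into_ids now lets sentencepiece map out-of-vocabulary pieces itself. A rough, hypothetical sketch of that behavior (not the icefall helper; the model path is made up):

```
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/lang_bpe_500/bpe.model")  # hypothetical path
texts = ["HELLO WORLD", "GOOD MORNING"]
# sp.encode maps unknown pieces to sp.unk_id() on its own, so callers
# no longer need to pass unk_id explicitly.
y = sp.encode(texts, out_type=int)
```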