mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix speech_llm recipe (#1936)

* Fix the training and decoding scripts, clean up unused code, and ensure compliance with style checks.

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

This commit is contained in:
parent cc2e64a6aa
commit cd7caf12df
@@ -55,7 +55,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
   --deepspeed \
   --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora False --unfreeze-llm False
+  --use-lora False \
+  --unfreeze-llm False

 # Then, we jointly train the projector and LLM LoRA modules.
 torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
@@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
   --deepspeed \
   --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora True --unfreeze-llm True
+  --use-lora True \
+  --unfreeze-llm True \
+  --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
 ```
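The two invocations above are the recipe's two training stages: first only the encoder projector is trained, then the projector and LLM LoRA adapters are trained jointly. A minimal sketch of the freezing behaviour these flags control, assuming the model exposes `encoder`, `llm`, and `encoder_projector` submodules as in the diff further below (a sketch, not the recipe's own code):

```python
# Illustrative only: how --use-lora / --unfreeze-llm gate parameter updates.
def configure_trainable_params(model, use_lora: bool, unfreeze_llm: bool) -> None:
    for p in model.encoder.parameters():
        p.requires_grad = False  # the Whisper encoder is always frozen
    for p in model.llm.parameters():
        # Stage 1: frozen LLM; Stage 2: unfrozen via LoRA adapters only
        p.requires_grad = use_lora and unfreeze_llm
    for p in model.encoder_projector.parameters():
        p.requires_grad = True  # the projector is trained in both stages
```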

@@ -77,11 +79,11 @@ mkdir -p models/whisper models/qwen models/checkpoint
 huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B

 # For aishell fine-tuned whisper model
 huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
 # For multi-hans fine-tuned whisper model
 # huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt

 huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct

 mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
 ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
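For scripted setups, the same downloads can be done from Python via `huggingface_hub`; a sketch using the repo and file names from the commands above:

```python
from huggingface_hub import hf_hub_download, snapshot_download

# Full repo snapshots for the checkpoint and the Qwen LLM.
snapshot_download(
    repo_id="yuekai/icefall_asr_aishell_whisper_qwen2_1.5B",
    local_dir="models/checkpoint",
)
snapshot_download(repo_id="Qwen/Qwen2-7B-Instruct", local_dir="models/qwen")

# Single file for the aishell fine-tuned Whisper model.
hf_hub_download(
    repo_id="yuekai/icefall_asr_aishell_whisper",
    filename="exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt",
    local_dir="models/whisper",
)
```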

@@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \
   --epoch 999 --avg 1 \
   --manifest-dir data/fbank \
   --use-flash-attn True \
-  --use-lora True --dataset aishell
+  --use-lora True \
+  --dataset aishell
 ```

@@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.checkpoint import load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -357,43 +357,6 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "beam-search".
     """
-
-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-            return text
-
     results = []

     num_cuts = 0
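For reference, the removed helper strips M2MeT-style tags, spaces, and CJK punctuation, upper-cases Latin text, and maps full-width letters to half-width. A hypothetical before/after, assuming the function as defined above (the input string is illustrative, not taken from the dataset):

```python
raw = "今天 <sil> 开会 <%> ok，好的。"
normalized = normalize_text_alimeeting(raw, "m2met")
# Tags and spaces are stripped, "ok" is upper-cased, and the full-width
# comma and period are removed:
assert normalized == "今天开会OK好的"
```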

@@ -406,6 +369,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list("".join(text.split())) for text in texts]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(

@@ -418,12 +382,8 @@ def decode_dataset(
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_text = normalize_text_alimeeting(ref_text)
-                ref_words = ref_text.split()
-                print(f"ref: {ref_text}")
-                print(f"hyp: {''.join(hyp_words)}")
-                this_batch.append((cut_id, ref_words, hyp_words))
+            for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_text))

             results[lm_scale].extend(this_batch)

@@ -439,40 +399,38 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
-    enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
+        logging.info(f"The transcripts are stored in {recog_path}")

-        # The following prints out WERs, per-word error statistics and aligned
+        # The following prints out CERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        # we compute CER for aishell dataset.
-        results_char = []
-        for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                compute_CER=True,
             )
             test_set_wers[key] = wer

-        if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
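The rewrite drops the manual per-character re-splitting (`results_char`) and instead asks icefall's helpers for character-level scoring via `char_level=True` and `compute_CER=True`, which is the appropriate metric for Chinese ASR. The underlying idea, as a standalone sketch (not icefall's implementation):

```python
def cer(ref: str, hyp: str) -> float:
    """Character error rate: Levenshtein distance over characters,
    normalized by the reference length."""
    r, h = list(ref), list(hyp)
    d = list(range(len(h) + 1))  # DP row for the empty-prefix baseline
    for i, rc in enumerate(r, 1):
        prev, d[0] = d[0], i
        for j, hc in enumerate(h, 1):
            # deletion, insertion, substitution/match
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (rc != hc))
    return d[len(h)] / max(len(r), 1)

print(cer("今天开会", "今天开灰"))  # 1 substitution over 4 chars -> 0.25
```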

@@ -495,9 +453,13 @@ def main():
     params = get_params()
     params.update(vars(args))

+    params.res_dir = params.exp_dir / f"{params.method}"
+
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     setup_logger(
-        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+        params.res_dir
+        / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
     )

     logging.info("Decoding started")

@@ -574,23 +536,20 @@ def main():
     if params.avg > 1:
         start = params.epoch - params.avg + 1
         assert start >= 1, start
-        checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-        )
-        assert "model" not in checkpoint
         # deepspeed converted checkpoint only contains model state_dict
         filenames = [
-            f"{params.exp_dir}/epoch-{epoch}.pt"
+            f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
             for epoch in range(start, params.epoch + 1)
         ]
         avg_checkpoint = average_checkpoints(filenames)
         model.load_state_dict(avg_checkpoint, strict=False)

-        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-        torch.save(avg_checkpoint, filename)
+        # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+        # torch.save(avg_checkpoint, filename)
     else:
         checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
+            map_location="cpu",
         )
         model.load_state_dict(checkpoint, strict=False)
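Decoding now loads the plain state dicts produced by the DeepSpeed-to-fp32 conversion (`epoch-N/pytorch_model.bin`) rather than `.pt` files. Averaging such checkpoints is a per-tensor mean; a sketch of what `average_checkpoints` does under that assumption (not the icefall implementation itself):

```python
import torch

def average_checkpoints(filenames):
    # Load the first state_dict, accumulate the rest, then divide.
    avg = torch.load(filenames[0], map_location="cpu")
    for f in filenames[1:]:
        sd = torch.load(f, map_location="cpu")
        for k in avg:
            avg[k] += sd[k]
    for k in avg:
        # Skip integer buffers (e.g. step counters); average float tensors.
        if avg[k].is_floating_point():
            avg[k] /= len(filenames)
    return avg
```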

@@ -643,8 +602,7 @@ def main():
     logging.info("Done!")


-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     main()

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
 #              2024  Yuekai Zhang
+#              2025  Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
 """

 import argparse
 import copy
 import logging
 import os
-import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple

 import deepspeed
-import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 from icefall import diagnostics
 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    filter_uneven_sized_batch,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 DEFAULT_SPEECH_TOKEN = "<speech>"

@@ -286,13 +272,6 @@ def compute_loss(
     Returns:
       Return a tuple of two elements. The first element is the loss tensor.
     """
-    # For the uneven-sized batch, the total duration after padding would possibly
-    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
-    # we simply drop the last few shortest samples, so that the retained total frames
-    # (after padding) would not exceed `allowed_max_frames`:
-    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
-    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
-    # We set allowed_excess_duration_ratio=0.1.

     def preprocess(
         messages,
@@ -347,46 +326,6 @@ def compute_loss(

     return input_ids, attention_mask, target_ids

-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-            return text
-
-    max_frames = params.max_duration * 1000 // params.frame_shift_ms
-    allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-    batch = filter_uneven_sized_batch(batch, allowed_max_frames)

     device = next(model.parameters()).device
     feature = batch["inputs"]

@@ -397,11 +336,10 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
     texts = batch["supervisions"]["text"]
-    # remove spaces in texts
-    texts = [normalize_text_alimeeting(text) for text in texts]

     messages = []
     for i, text in enumerate(texts):
+        text = text.replace(" ", "")
         message = [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
             {"role": "assistant", "content": text},

@@ -516,14 +454,17 @@ def train_one_epoch(
       The rank of the node in DDP training. If no DDP is used, it should
       be set to 0.
     """
-    model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()

     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,

@@ -533,6 +474,9 @@ def train_one_epoch(
                 world_size=world_size,
             )
+            model.train()
+            model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"

@@ -648,7 +592,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()

     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:

@@ -671,7 +614,6 @@ def run(rank, world_size, args):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
     else:
         if params.use_lora:
             lora_config = LoraConfig(
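The hunk is truncated at the `LoraConfig(` call, so the recipe's actual rank, alpha, and target modules are not visible here. For orientation, a typical LoRA configuration for a Qwen-style decoder looks like the sketch below; all values are placeholders:

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=64,                # placeholder rank
    lora_alpha=16,       # placeholder scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # common choice
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
# llm = get_peft_model(llm, lora_config)  # wraps the frozen base with adapters
```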

@@ -728,7 +670,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)

-    assert params.deepspeed and world_size > 1
+    assert params.deepspeed
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
         args=params, model=model, model_parameters=model.parameters()
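Relaxing the assert allows single-GPU DeepSpeed runs. `deepspeed.initialize` reads the JSON passed via `--deepspeed_config` (here `ds_config_zero1.json`); a ZeRO stage-1 setup can equivalently be sketched with an inline config dict, with placeholder batch and precision values:

```python
import deepspeed
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the SPEECH_LLM model
ds_config = {  # placeholder values; the recipe ships ds_config_zero1.json
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {"stage": 1},  # ZeRO-1: shard optimizer states only
    "bf16": {"enabled": True},
}
model_engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)
```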

@@ -764,7 +706,7 @@ def run(rank, world_size, args):
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-        sampler_state_dict["max_duration"] = params.max_duration
+        # TODO: load sampler state dict

     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
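The `max_duration` override is dropped and restoring the sampler state is left as a TODO, though the loaded dict is still forwarded to `train_dataloaders`. For context, lhotse samplers support resumption roughly like this (a sketch of the lhotse `CutSampler` API, not what the recipe currently does):

```python
# Fragment: resuming a lhotse sampler from a saved state dict.
sampler_state_dict = torch.load(params.sampler_state_dict_path)
train_dl.sampler.load_state_dict(sampler_state_dict)  # lhotse CutSampler API
```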

@@ -806,15 +748,15 @@ def run(rank, world_size, args):

         model.save_checkpoint(
             save_dir=params.exp_dir,
-            tag=f"epoch-{params.cur_epoch}",
+            tag=f"zero-epoch-{params.cur_epoch}",
             client_state={},
             exclude_frozen_parameters=True,
         )
         if rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
-                f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                tag=f"epoch-{params.cur_epoch}",
+                f"{params.exp_dir}/epoch-{params.cur_epoch}",
+                tag=f"zero-epoch-{params.cur_epoch}",
                 exclude_frozen_parameters=True,
             )
             # save sampler state dict into checkpoint
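With this retagging, the raw ZeRO shards live under `zero-epoch-N` (removed after conversion, see the next hunk), while the consolidated fp32 weights are written to the directory `epoch-N`; this is why `decode.py` above loads `epoch-N/pytorch_model.bin`. A sketch of loading such a converted checkpoint, assuming `params` and `model` as in the surrounding code:

```python
import torch

# epoch-N is a directory written by convert_zero_checkpoint_to_fp32_state_dict;
# it contains pytorch_model.bin holding only the non-frozen model weights.
state_dict = torch.load(
    f"{params.exp_dir}/epoch-{params.cur_epoch}/pytorch_model.bin",
    map_location="cpu",
)
model.load_state_dict(state_dict, strict=False)  # strict=False: frozen params absent
```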

@@ -824,7 +766,7 @@ def run(rank, world_size, args):
                 f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
             )

-            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
+            os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")

     logging.info("Done!")

@@ -865,6 +807,7 @@ def main():

     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)