Fix speech_llm recipe (#1936)

* fix training/decoding scripts, cleanup unused code, and ensure compliance with style checks

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
This commit is contained in:
Yifan Yang 2025-04-30 11:41:00 +08:00 committed by GitHub
parent cc2e64a6aa
commit cd7caf12df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 156 deletions

View File

@ -55,7 +55,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \ --deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \ --use-flash-attn True \
--use-lora False --unfreeze-llm False --use-lora False \
--unfreeze-llm False
# Then, we jointly train the projector and LLM LoRA modules. # Then, we jointly train the projector and LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \ --deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \ --use-flash-attn True \
--use-lora True --unfreeze-llm True --use-lora True \
--unfreeze-llm True \
--pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
``` ```
@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \
--epoch 999 --avg 1 \ --epoch 999 --avg 1 \
--manifest-dir data/fbank \ --manifest-dir data/fbank \
--use-flash-attn True \ --use-flash-attn True \
--use-lora True --dataset aishell --use-lora True \
--dataset aishell
``` ```

View File

@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -357,43 +357,6 @@ def decode_dataset(
Returns: Returns:
Return a dict, whose key may be "beam-search". Return a dict, whose key may be "beam-search".
""" """
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
results = [] results = []
num_cuts = 0 num_cuts = 0
@ -406,6 +369,7 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
@ -418,12 +382,8 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items(): for lm_scale, hyps in hyps_dict.items():
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text) this_batch.append((cut_id, ref_text, hyp_text))
ref_words = ref_text.split()
print(f"ref: {ref_text}")
print(f"hyp: {''.join(hyp_words)}")
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch) results[lm_scale].extend(this_batch)
@ -439,40 +399,38 @@ def decode_dataset(
def save_results( def save_results(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
): ):
enable_log = True
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = ( recog_path = (
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
) )
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = ( errs_filename = (
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
) )
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
for key, val in test_set_wers: for key, val in test_set_wers:
@ -495,9 +453,13 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params.res_dir = params.exp_dir / f"{params.method}"
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger( setup_logger(
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" params.res_dir
/ f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
) )
logging.info("Decoding started") logging.info("Decoding started")
@ -574,23 +536,20 @@ def main():
if params.avg > 1: if params.avg > 1:
start = params.epoch - params.avg + 1 start = params.epoch - params.avg + 1
assert start >= 1, start assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict # deepspeed converted checkpoint only contains model state_dict
filenames = [ filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt" f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
for epoch in range(start, params.epoch + 1) for epoch in range(start, params.epoch + 1)
] ]
avg_checkpoint = average_checkpoints(filenames) avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False) model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename) # torch.save(avg_checkpoint, filename)
else: else:
checkpoint = torch.load( checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
map_location="cpu",
) )
model.load_state_dict(checkpoint, strict=False) model.load_state_dict(checkpoint, strict=False)
@ -643,8 +602,7 @@ def main():
logging.info("Done!") logging.info("Done!")
if __name__ == "__main__":
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
if __name__ == "__main__":
main() main()

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) # Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang # 2024 Yuekai Zhang
# 2025 Yifan Yang
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
""" """
import argparse import argparse
import copy
import logging import logging
import os import os
import random
import warnings import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from typing import Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed import deepspeed
import k2
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import transformers import transformers
import whisper import whisper
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from peft import LoraConfig, get_peft_model
from torch import Tensor from torch import Tensor
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
from icefall.dist import get_rank, get_world_size from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
AttributeDict,
MetricsTracker,
filter_uneven_sized_batch,
setup_logger,
str2bool,
)
DEFAULT_SPEECH_TOKEN = "<speech>" DEFAULT_SPEECH_TOKEN = "<speech>"
@ -286,13 +272,6 @@ def compute_loss(
Returns: Returns:
Return a tuple of two elements. The first element is the loss tensor. Return a tuple of two elements. The first element is the loss tensor.
""" """
# For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
def preprocess( def preprocess(
messages, messages,
@ -347,46 +326,6 @@ def compute_loss(
return input_ids, attention_mask, target_ids return input_ids, attention_mask, target_ids
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = next(model.parameters()).device device = next(model.parameters()).device
feature = batch["inputs"] feature = batch["inputs"]
@ -397,11 +336,10 @@ def compute_loss(
batch_idx_train = params.batch_idx_train batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [normalize_text_alimeeting(text) for text in texts]
messages = [] messages = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
text = text.replace(" ", "")
message = [ message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text}, {"role": "assistant", "content": text},
@ -516,14 +454,17 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should The rank of the node in DDP training. If no DDP is used, it should
be set to 0. be set to 0.
""" """
model.encoder_projector.train() model.train()
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
valid_info = compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
@ -533,6 +474,9 @@ def train_one_epoch(
world_size=world_size, world_size=world_size,
) )
model.train() model.train()
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info( logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@ -648,7 +592,6 @@ def run(rank, world_size, args):
speech_encoder_dim = whisper_model.dims.n_audio_state speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters(): for name, param in speech_encoder.named_parameters():
param.requires_grad = False param.requires_grad = False
speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn: if params.use_flash_attn:
@ -671,7 +614,6 @@ def run(rank, world_size, args):
if not params.unfreeze_llm: if not params.unfreeze_llm:
for name, param in llm.named_parameters(): for name, param in llm.named_parameters():
param.requires_grad = False param.requires_grad = False
llm.eval()
else: else:
if params.use_lora: if params.use_lora:
lora_config = LoraConfig( lora_config = LoraConfig(
@ -728,7 +670,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
model.to(device) model.to(device)
assert params.deepspeed and world_size > 1 assert params.deepspeed
logging.info("Using DeepSpeed") logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize( model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters() args=params, model=model, model_parameters=model.parameters()
@ -764,7 +706,7 @@ def run(rank, world_size, args):
if params.sampler_state_dict_path: if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path) sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration sampler_state_dict["max_duration"] = params.max_duration
# TODO: load sampler state dict
train_dl = data_module.train_dataloaders( train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict train_cuts, sampler_state_dict=sampler_state_dict
) )
@ -806,15 +748,15 @@ def run(rank, world_size, args):
model.save_checkpoint( model.save_checkpoint(
save_dir=params.exp_dir, save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}", tag=f"zero-epoch-{params.cur_epoch}",
client_state={}, client_state={},
exclude_frozen_parameters=True, exclude_frozen_parameters=True,
) )
if rank == 0: if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict( convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir, params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}",
tag=f"epoch-{params.cur_epoch}", tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True, exclude_frozen_parameters=True,
) )
# save sampler state dict into checkpoint # save sampler state dict into checkpoint
@ -824,7 +766,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
) )
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!") logging.info("Done!")
@ -865,6 +807,7 @@ def main():
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args) run(rank=rank, world_size=world_size, args=args)