From cd7caf12df92ac21f67a4027cf80b940ee15bef6 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 30 Apr 2025 11:41:00 +0800 Subject: [PATCH] Fix speech_llm recipe (#1936) * fix training/decoding scripts, cleanup unused code, and ensure compliance with style checks --------- Co-authored-by: Your Name Co-authored-by: Fangjun Kuang --- egs/speech_llm/ASR_LLM/RESULTS.md | 15 ++- .../ASR_LLM/whisper_llm_zh/decode.py | 104 ++++++------------ .../ASR_LLM/whisper_llm_zh/train.py | 97 ++++------------ 3 files changed, 60 insertions(+), 156 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md index 01c48a82e..42dce80c5 100644 --- a/egs/speech_llm/ASR_LLM/RESULTS.md +++ b/egs/speech_llm/ASR_LLM/RESULTS.md @@ -55,7 +55,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora False --unfreeze-llm False + --use-lora False \ + --unfreeze-llm False # Then, we jointly train the projector and LLM LoRA modules. torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ @@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora True --unfreeze-llm True + --use-lora True \ + --unfreeze-llm True \ --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt ``` @@ -77,11 +79,11 @@ mkdir -p models/whisper models/qwen models/checkpoint huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B # For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt # For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt -huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt @@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \ --epoch 999 --avg 1 \ --manifest-dir data/fbank \ --use-flash-attn True \ - --use-lora True --dataset aishell + --use-lora True \ + --dataset aishell ``` diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 882ce4fbf..3036b471e 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN from transformers import AutoModelForCausalLM, AutoTokenizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.checkpoint import load_checkpoint from icefall.env import get_env_info from icefall.utils import ( AttributeDict, @@ -357,43 +357,6 @@ def decode_dataset( 
    Returns:
      Return a dict, whose key may be "beam-search".
    """
-
-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-            return text
-
     results = []
     num_cuts = 0

@@ -406,6 +369,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list("".join(text.split())) for text in texts]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
@@ -418,12 +382,8 @@
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_text = normalize_text_alimeeting(ref_text)
-                ref_words = ref_text.split()
-                print(f"ref: {ref_text}")
-                print(f"hyp: {''.join(hyp_words)}")
-                this_batch.append((cut_id, ref_words, hyp_words))
+            for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_text))

             results[lm_scale].extend(this_batch)

@@ -439,40 +399,38 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
-
-    enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
+        logging.info(f"The transcripts are stored in {recog_path}")

-        # The following prints out WERs, per-word error statistics and aligned
+        # The following prints out CERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        # we compute CER for aishell dataset.
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -495,9 +453,13 @@ def main(): params = get_params() params.update(vars(args)) + + params.res_dir = params.exp_dir / f"{params.method}" + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + params.res_dir + / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}" ) logging.info("Decoding started") @@ -574,23 +536,20 @@ def main(): if params.avg > 1: start = params.epoch - params.avg + 1 assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - assert "model" not in checkpoint # deepspeed converted checkpoint only contains model state_dict filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" + f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin" for epoch in range(start, params.epoch + 1) ] avg_checkpoint = average_checkpoints(filenames) model.load_state_dict(avg_checkpoint, strict=False) - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(avg_checkpoint, filename) + # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + # torch.save(avg_checkpoint, filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", + map_location="cpu", ) model.load_state_dict(checkpoint, strict=False) @@ -643,8 +602,7 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 5f224c984..7947a60a5 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright 2023 Xiaomi Corp. 
(authors: Xiaoyu Yang)
 # 2024 Yuekai Zhang
+# 2025 Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
 """

 import argparse
-import copy
 import logging
 import os
-import random
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple

 import deepspeed
-import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
-from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

-from icefall import diagnostics
 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    filter_uneven_sized_batch,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 DEFAULT_SPEECH_TOKEN = "<speech>"

@@ -286,13 +272,6 @@ def compute_loss(
     Returns:
         Return a tuple of two elements. The first element is the loss tensor.
     """
-    # For the uneven-sized batch, the total duration after padding would possibly
-    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
-    # we simply drop the last few shortest samples, so that the retained total frames
-    # (after padding) would not exceed `allowed_max_frames`:
-    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
-    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
-    # We set allowed_excess_duration_ratio=0.1.

     def preprocess(
         messages,
@@ -347,46 +326,6 @@ def compute_loss(

         return input_ids, attention_mask, target_ids

-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-            return text
-
-    max_frames = params.max_duration * 1000 // params.frame_shift_ms
-    allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
-    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
-
     device = next(model.parameters()).device
     feature = batch["inputs"]

@@ -397,11 +336,10 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train
     supervisions = batch["supervisions"]
     texts = batch["supervisions"]["text"]
-    # remove spaces in texts
-    texts = [normalize_text_alimeeting(text) for text in texts]

     messages = []
     for i, text in enumerate(texts):
+        text = text.replace(" ", "")
         message = [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
             {"role": "assistant", "content": text},
@@ -516,14 +454,17 @@ def train_one_epoch(
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
""" - model.encoder_projector.train() + model.train() + model.encoder.eval() + if not params.unfreeze_llm: + model.llm.eval() tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -533,6 +474,9 @@ def train_one_epoch( world_size=world_size, ) model.train() + model.encoder.eval() + if not params.unfreeze_llm: + model.llm.eval() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -648,7 +592,6 @@ def run(rank, world_size, args): speech_encoder_dim = whisper_model.dims.n_audio_state for name, param in speech_encoder.named_parameters(): param.requires_grad = False - speech_encoder.eval() tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) if params.use_flash_attn: @@ -671,7 +614,6 @@ def run(rank, world_size, args): if not params.unfreeze_llm: for name, param in llm.named_parameters(): param.requires_grad = False - llm.eval() else: if params.use_lora: lora_config = LoraConfig( @@ -728,7 +670,7 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") model.to(device) - assert params.deepspeed and world_size > 1 + assert params.deepspeed logging.info("Using DeepSpeed") model, optimizer, _, scheduler = deepspeed.initialize( args=params, model=model, model_parameters=model.parameters() @@ -764,7 +706,7 @@ def run(rank, world_size, args): if params.sampler_state_dict_path: sampler_state_dict = torch.load(params.sampler_state_dict_path) sampler_state_dict["max_duration"] = params.max_duration - # TODO: load sampler state dict + train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) @@ -806,15 +748,15 @@ def run(rank, world_size, args): model.save_checkpoint( save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", client_state={}, exclude_frozen_parameters=True, ) if rank == 0: convert_zero_checkpoint_to_fp32_state_dict( params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", - tag=f"epoch-{params.cur_epoch}", + f"{params.exp_dir}/epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", exclude_frozen_parameters=True, ) # save sampler state dict into checkpoint @@ -824,7 +766,7 @@ def run(rank, world_size, args): f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", ) - os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") + os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}") logging.info("Done!") @@ -865,6 +807,7 @@ def main(): torch.set_num_threads(1) torch.set_num_interop_threads(1) + warnings.filterwarnings("ignore", category=FutureWarning) run(rank=rank, world_size=world_size, args=args)