diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 239080014..9e1646808 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
 #              2024  Yuekai Zhang
+#              2025  Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,38 +43,29 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
 """
 
 import argparse
-import copy
 import logging
 import os
-import random
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple
 
 import deepspeed
-import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
-from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
-from icefall import diagnostics
 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -516,7 +508,10 @@ def train_one_epoch(
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
     """
-    model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
 
     tot_loss = MetricsTracker()
 
@@ -533,6 +528,9 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
+            model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -648,7 +646,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
@@ -671,7 +668,6 @@ def run(rank, world_size, args):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
     else:
         if params.use_lora:
             lora_config = LoraConfig(
@@ -728,7 +724,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)
 
-    assert params.deepspeed and world_size > 1
+    assert params.deepspeed
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
         args=params, model=model, model_parameters=model.parameters()
@@ -865,6 +861,7 @@ def main():
 
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)
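
Note on the pattern in the diff: param.requires_grad = False only stops weight updates; it does not stop Dropout from firing or normalization layers from updating running statistics. Because model.train() (including the model.train() call after each validation pass) flips every submodule back to train mode, the frozen encoder and LLM must be explicitly returned to eval mode each time, which is why the one-shot speech_encoder.eval() / llm.eval() calls in run() are removed in favor of re-asserting eval mode inside train_one_epoch(). A minimal self-contained sketch of the pattern follows; ToySpeechLLM and set_train_mode are illustrative names (not the recipe's actual classes), while the attribute names encoder, llm, encoder_projector and the unfreeze_llm flag mirror the diff:

#!/usr/bin/env python3
# Sketch of the freeze/eval pattern; names marked above are illustrative.
import torch.nn as nn


class ToySpeechLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))  # frozen
        self.llm = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))  # frozen unless unfreeze_llm
        self.encoder_projector = nn.Linear(8, 8)  # always trained


def set_train_mode(model: ToySpeechLLM, unfreeze_llm: bool) -> None:
    # model.train() flips *every* submodule to train mode, so the frozen
    # parts must be flipped back; call this again after each validation pass.
    model.train()
    model.encoder.eval()
    if not unfreeze_llm:
        model.llm.eval()


model = ToySpeechLLM()
for p in model.encoder.parameters():
    p.requires_grad = False  # stops gradient updates, but not Dropout behavior

set_train_mode(model, unfreeze_llm=False)
assert not model.encoder.training  # Dropout in the frozen encoder is now inert
assert model.encoder_projector.training  # the projector still trains normally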