fix train/eval mode

reformat

reformat

fix
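Note on the train/eval fix: previously `train_one_epoch` only called `model.encoder_projector.train()`, while the frozen Whisper encoder and LLM were put into eval mode once at construction time. Since `Module.train()` recursively flips every submodule back to training mode, and the validation path toggles modes as well, the fix calls `model.train()` and then re-applies `.eval()` to the frozen encoder (and to the LLM unless --unfreeze-llm is given), both at the start of the epoch and right after validation. Setting `requires_grad = False` alone does not change layer behavior, so dropout in a frozen module would still fire in train mode. A minimal sketch with toy stand-in modules (not code from the recipe) illustrating the pattern:

# Minimal sketch with toy stand-in modules (not from the recipe):
# Module.train() sets training=True recursively, so frozen parts must be
# switched back to eval() afterwards; requires_grad=False alone does not
# disable dropout or other train-time behavior.
import torch.nn as nn

model = nn.ModuleDict(
    {
        "encoder": nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5)),
        "encoder_projector": nn.Linear(4, 4),
        "llm": nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5)),
    }
)
for p in model["encoder"].parameters():
    p.requires_grad = False  # frozen for the optimizer, but still in train mode

model.train()            # everything, including the frozen encoder, is now training=True
model["encoder"].eval()  # so dropout in the frozen encoder must be disabled again
model["llm"].eval()      # same for the LLM unless --unfreeze-llm is given

assert model["encoder_projector"].training
assert not model["encoder"].training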
yfyeung 2025-04-28 09:10:12 +00:00 committed by Your Name
parent 59c577f4ef
commit f5d2aa1f5d

whisper_llm_zh/train.py

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright   2023  Xiaomi Corp.  (authors: Xiaoyu Yang)
 #             2024  Yuekai Zhang
+#             2025  Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,38 +43,29 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
 """
 import argparse
 import copy
 import logging
 import os
 import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple

 import deepspeed
 import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 from icefall import diagnostics
 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -516,7 +508,10 @@ def train_one_epoch(
       The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
     """
-    model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
     tot_loss = MetricsTracker()
@@ -533,6 +528,9 @@
                 world_size=world_size,
             )
             model.train()
+            model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -648,7 +646,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()

     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
@@ -671,7 +668,6 @@ def run(rank, world_size, args):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
     else:
         if params.use_lora:
             lora_config = LoraConfig(
@@ -728,7 +724,7 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)

-    assert params.deepspeed and world_size > 1
+    assert params.deepspeed
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
         args=params, model=model, model_parameters=model.parameters()
@@ -865,6 +861,7 @@ def main():
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)

     run(rank=rank, world_size=world_size, args=args)
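Note on the relaxed assertion: dropping `world_size > 1` means DeepSpeed is still required, but single-process runs are now allowed; the new `warnings.filterwarnings` call uses the `warnings` import already present at the top of the file. A minimal sketch (not code from the recipe) of a single-process DeepSpeed initialization, using DeepSpeed's real `add_config_arguments`/`initialize` API but a hypothetical `ds_config.json`:

# Minimal sketch (not from the recipe): with the relaxed assertion, a single
# process launched via `torchrun --nproc_per_node 1` can still initialize
# DeepSpeed. Assumes deepspeed is installed and a hypothetical ds_config.json
# (defining train_batch_size etc.) exists next to the script.
import argparse

import deepspeed
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed, --deepspeed_config
args = parser.parse_args(["--deepspeed", "--deepspeed_config", "ds_config.json"])

model = nn.Linear(4, 4)
engine, optimizer, _, scheduler = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters()
)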