fix train/eval mode

reformat

reformat

fix
yfyeung 2025-04-28 09:10:12 +00:00 committed by Your Name
parent 59c577f4ef
commit f5d2aa1f5d


@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
 #           2024 Yuekai Zhang
+#           2025 Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,38 +43,29 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
 """
 
 import argparse
-import copy
 import logging
 import os
-import random
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple
 
 import deepspeed
-import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
-from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
-from icefall import diagnostics
 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -516,7 +508,10 @@ def train_one_epoch(
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
     """
-    model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
 
     tot_loss = MetricsTracker()
 
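Why both extra calls are needed: nn.Module.train() is recursive, so it also flips the frozen Whisper encoder (and the LLM, unless params.unfreeze_llm is set) back into training mode, re-enabling dropout there. A minimal sketch of that behaviour with a toy module (illustrative only, not the repository's SPEECH_LLM):

import torch.nn as nn

# Toy stand-in, not the real SPEECH_LLM: only the projector is trained,
# the encoder and llm stay frozen.
class ToySpeechLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))
        self.encoder_projector = nn.Linear(8, 8)
        self.llm = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))

model = ToySpeechLLM()
model.train()         # recursively puts encoder and llm into train mode too
model.encoder.eval()  # turn dropout back off in the frozen encoder
model.llm.eval()      # likewise for the frozen LLM
assert model.encoder_projector.training
assert not model.encoder[1].training and not model.llm[1].training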
@@ -533,6 +528,9 @@
                 world_size=world_size,
             )
             model.train()
+            model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -648,7 +646,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
@@ -671,7 +668,6 @@ def run(rank, world_size, args):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
     else:
         if params.use_lora:
             lora_config = LoraConfig(
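The two removals above (speech_encoder.eval() and llm.eval() at construction time) rely on the fact that parameter freezing and train/eval mode are independent switches, and eval mode is now re-applied inside train_one_epoch instead. A short sketch of the distinction, using a toy module rather than the script's encoder or LLM:

import torch.nn as nn

# Freezing parameters stops optimizer updates, but it does not disable
# dropout or freeze batch-norm statistics; that is controlled separately
# by train/eval mode.
frozen = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
for param in frozen.parameters():
    param.requires_grad = False  # no gradients, no updates
print(frozen.training)           # True: dropout would still be active
frozen.eval()                    # switch run-time behaviour separately
print(frozen.training)           # False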
@@ -728,7 +724,7 @@
     logging.info(f"Device: {device}")
     model.to(device)
 
-    assert params.deepspeed and world_size > 1
+    assert params.deepspeed
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
         args=params, model=model, model_parameters=model.parameters()
@@ -865,6 +861,7 @@ def main():
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)