Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
fix train/eval mode

Previously train_one_epoch only set the encoder projector to training mode. This commit calls model.train() and then pins the frozen Whisper encoder (and the LLM, when it stays frozen) back to eval mode, both at the start of each epoch and again after validation. It also removes unused imports, relaxes the DeepSpeed assertion so single-GPU runs are allowed, and suppresses FutureWarnings.
parent 59c577f4ef
commit f5d2aa1f5d
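The core pitfall this fixes: nn.Module.train() recursively puts every submodule back into training mode, silently undoing an earlier .eval() on a frozen part. A minimal sketch of the behavior, using a toy module rather than the recipe's model:

import torch.nn as nn

model = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(4, 4))
model[0].eval()           # pin the dropout submodule to eval mode
print(model[0].training)  # False
model.train()             # recursive: flips *all* children back
print(model[0].training)  # True -- the earlier .eval() was undone

# Hence the fixed order below: call model.train() first, then
# re-apply .eval() to the submodules that must stay frozen.
model.train()
model[0].eval()
print(model[0].training)  # False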
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
 #           2024 Yuekai Zhang
+#           2025 Yifan Yang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,38 +43,29 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
 """
 
 import argparse
-import copy
 import logging
 import os
-import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple
 
 import deepspeed
-import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 import whisper
 from asr_datamodule import AsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
 from icefall import diagnostics
 from icefall.dist import get_rank, get_world_size
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -516,7 +508,10 @@ def train_one_epoch(
         The rank of the node in DDP training. If no DDP is used, it should
         be set to 0.
     """
-    model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
 
     tot_loss = MetricsTracker()
 
@@ -533,6 +528,9 @@
                 world_size=world_size,
             )
             model.train()
+            model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
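The same mode-setting block now appears twice: at the top of train_one_epoch and again after validation (which puts the whole model into eval mode). A small helper, hypothetical and not part of this commit, would keep the two sites in sync:

# Hypothetical helper (not in the recipe): one place to define which
# submodules stay in eval mode while the rest of the model trains.
def set_train_mode(model, params):
    model.train()             # training mode for everything first...
    model.encoder.eval()      # ...then pin the frozen Whisper encoder
    if not params.unfreeze_llm:
        model.llm.eval()      # and the LLM, unless it is being tuned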
@@ -648,7 +646,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
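Removing speech_encoder.eval() here does not unfreeze anything: requires_grad = False controls whether gradients flow, while .train()/.eval() only toggles layer behavior (dropout, batch-norm statistics). Mode handling now lives in train_one_epoch. The two flags are independent, as this toy check (an illustration, not from the recipe) shows:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
for p in layer.parameters():
    p.requires_grad = False  # frozen: no gradients will be computed

layer.train()                # but still in training *mode*
print(layer.training)        # True
out = layer(torch.randn(1, 4)).sum()
print(out.requires_grad)     # False -- nothing to backprop into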
@@ -671,7 +668,6 @@
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
     else:
         if params.use_lora:
             lora_config = LoraConfig(
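In the unfrozen branch, the truncated LoraConfig( call configures PEFT's LoRA adapters; its arguments are elided in this hunk. A typical configuration for a causal LLM looks like the sketch below, where the model id and all hyperparameter values are illustrative assumptions, not the recipe's:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder model id; the recipe takes --llm-path-or-name instead.
llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
lora_config = LoraConfig(
    r=64,                    # rank of the low-rank update matrices
    lora_alpha=16,           # scaling factor for the update
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()  # only a small fraction is trainable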
@@ -728,7 +724,7 @@
     logging.info(f"Device: {device}")
     model.to(device)
 
-    assert params.deepspeed and world_size > 1
+    assert params.deepspeed
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
         args=params, model=model, model_parameters=model.parameters()
@@ -865,6 +861,7 @@ def main():
 
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)
 
 
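The new warnings.filterwarnings call silences FutureWarning for the process (deprecation notices commonly emitted by torch and transformers) while leaving other warning categories visible. A standalone demonstration of the mechanism, not taken from the recipe:

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.warn("silenced", FutureWarning)  # suppressed by the filter
warnings.warn("still visible")            # UserWarning: still printed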