refactor train to reuse code

root 2025-05-26 19:53:16 -07:00
parent e6e1f3fa4f
commit 39700d5c94
3 changed files with 632 additions and 84 deletions

View File

@@ -59,9 +59,9 @@ class SPEECH_LLM(nn.Module):
def __init__(
self,
encoder: nn.Module,
llm: nn.Module,
encoder_projector: nn.Module,
encoder: nn.Module = None,
llm: nn.Module = None,
encoder_projector: nn.Module = None,
codec_lm: nn.Module = None,
codec_lm_padding_side: str = "left",
teacher_llm: nn.Module = None,
@@ -330,20 +330,19 @@ class SPEECH_LLM(nn.Module):
labels: torch.LongTensor = None,
speech_codec_ids: torch.LongTensor = None,
):
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
if fbank is not None:
encoder_outs = self.encoder(fbank)
speech_features = self.encoder_projector(encoder_outs)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
(
inputs_embeds,
attention_mask,
labels,
_,
) = self._merge_input_ids_with_speech_features(
speech_features, inputs_embeds, input_ids, attention_mask, labels
)
input_seq_len = attention_mask.sum(dim=1) # shape, B
(
text_label_start_index_list,

View File

@@ -69,8 +69,6 @@ from transformers import (
Qwen2ForCausalLM,
)
# from icefall.env import get_env_info
# from icefall import diagnostics
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
@@ -137,6 +135,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Whether to enable speech codec output.",
)
parser.add_argument(
"--enable-speech-input",
type=str2bool,
default=True,
help="Whether to enable speech fbank input.",
)
parser.add_argument(
"--speech-tokenizer-type",
type=str,
@@ -145,11 +150,7 @@
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
def add_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--tensorboard",
type=str2bool,
@@ -243,6 +244,12 @@ def get_parser():
help="The name of the dataset.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--loss-type",
type=str,
@@ -252,7 +259,7 @@ def get_parser():
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
add_training_arguments(parser)
return parser
@@ -532,7 +539,6 @@ def compute_loss(
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
# WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
messages, answer_cosyvoice_speech_token = extract_text_and_speech_token(
batch, params.enable_speech_output
)
@@ -686,9 +692,9 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
# model.encoder_projector.train()
model.train()
model.encoder.eval()
if params.enable_speech_input:
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker()
@@ -706,7 +712,8 @@
world_size=world_size,
)
model.train()
model.encoder.eval()
if params.enable_speech_input:
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
@@ -796,36 +803,11 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
def get_model(params):
"""Load and prepare the speech-to-speech model."""
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
@@ -842,11 +824,9 @@ def run(rank, world_size, args):
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
else:
if params.use_lora:
lora_config = LoraConfig(
@@ -867,21 +847,29 @@ def run(rank, world_size, args):
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
DEFAULT_SPEECH_TOKEN
)
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if not params.unfreeze_speech_projector:
for name, param in encoder_projector.named_parameters():
if params.enable_speech_input:
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
speech_encoder = whisper_model.encoder
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
encoder_projector.eval()
encoder_projector = EncoderProjector(
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
)
if not params.unfreeze_speech_projector:
for name, param in encoder_projector.named_parameters():
param.requires_grad = False
encoder_projector.eval()
else:
speech_encoder = None
encoder_projector = None
if params.enable_speech_output:
# Determine attn_implementation and torch_dtype based on use_flash_attn
@@ -922,17 +910,6 @@ def run(rank, world_size, args):
codec_lm.config.mask_token_id = codec_vocab_size - 4
else:
codec_lm = None
if params.loss_type == "kl_div":
teacher_llm = AutoModelForCausalLM.from_pretrained(
params.llm_path_or_name,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
for name, param in teacher_llm.named_parameters():
param.requires_grad = False
teacher_llm.eval()
else:
teacher_llm = None
model = SPEECH_LLM(
speech_encoder,
@@ -940,9 +917,7 @@ def run(rank, world_size, args):
encoder_projector,
codec_lm,
codec_lm_padding_side="left" if params.use_flash_attn else "right",
teacher_llm=teacher_llm,
)
if params.pretrained_model_path or params.last_stage_model_path:
if params.pretrained_model_path is None:
checkpoint = torch.load(params.last_stage_model_path, map_location="cpu")
@@ -963,6 +938,32 @@ def run(rank, world_size, args):
if param.requires_grad:
logging.info(f"{name}: {param.shape}")
return model, tokenizer
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
model, tokenizer = get_model(params)
if torch.cuda.is_available():
device = torch.device("cuda", get_local_rank())
else:
@@ -1032,9 +1033,6 @@ def run(rank, world_size, args):
elif params.dataset == "gigaspeech":
train_cuts = data_module.train_cuts_gigaspeech()
valid_cuts = data_module.valid_cuts_ultravox()
elif params.dataset == "emilia_en":
train_cuts = data_module.train_cuts_emilia_en()
valid_cuts = data_module.valid_cuts_emilia_en()
else:
raise ValueError(f"Unknown dataset: {params.dataset}")
@@ -1049,7 +1047,6 @@ def run(rank, world_size, args):
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
# train_dl = data_module.valid_dataloaders(train_cuts)
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0:

View File

@@ -0,0 +1,552 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
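The command above mirrors the original speech-input recipe. Below is a sketch of a
text-to-speech-only run of this script; the file name train_tts.py is illustrative
(the actual name is not fixed here), and the flags come from add_model_arguments /
add_training_arguments:
torchrun --nproc_per_node $ngpu ./qwen_omni/train_tts.py \
--exp-dir $exp_dir \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True \
--enable-speech-input False --enable-speech-output True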
"""
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DistributedSampler, DataLoader
from train import add_model_arguments, add_training_arguments, get_params, compute_validation_loss, get_model, display_and_save_batch
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
get_local_rank,
get_rank,
get_world_size,
setup_logger,
str2bool,
)
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# parser.add_argument(
# "--loss-type",
# type=str,
# default="ce",
# help="The type of loss to use.",
# )
parser = deepspeed.add_config_arguments(parser)
add_model_arguments(parser)
add_training_arguments(parser)
return parser
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocesses the data for supervised fine-tuning."""
texts = []
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
for i, msg in enumerate(messages):
texts.append(
tokenizer.apply_chat_template(
msg,
tokenize=True,
chat_template=TEMPLATE,
add_generation_prompt=False,
padding="longest", # FIX me change padding to longest
truncation=False,
)
)
if len(texts) != len(messages):
logging.warning(f"Number of tokenized texts ({len(texts)}) does not match number of messages: {messages}")
max_len_texts = max([len(text) for text in texts])
if tokenizer.padding_side == "right":
texts = [
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
for text in texts
]
else:
texts = [
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
# first get the indices of the tokens
mask_prompt = True
if mask_prompt:
default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
mask_indices = torch.where(input_ids == default_speech_token_id)
for i in range(mask_indices[0].size(0)):
row = mask_indices[0][i]
col = mask_indices[1][i]
# + 6 to skip the chat-template tokens between the <speech> placeholder
# and the start of the assistant answer (role header, newlines, etc.)
# WAR: TODO FIXME verify this offset for qwen3
target_ids[row, : col + 6] = IGNORE_TOKEN_ID
# restore the <speech> placeholder so it is stripped from input_ids and
# target_ids consistently below and both keep the same shape
target_ids[row, col] = default_speech_token_id
# remove default_speech_token_id from target_ids and input_ids
batch_size = target_ids.size(0)
target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
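# Note on the outputs above: after the <speech> placeholder is filtered from both
# tensors, input_ids and target_ids share the same (batch, seq_len) shape;
# target_ids holds IGNORE_TOKEN_ID on padding and on every token up to the start
# of the assistant answer, so the text loss is computed on answer tokens only,
# while attention_mask simply marks non-padding positions.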
def data_collator(batch, tokenizer, cut_off_len=2048):
speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
for i, item in enumerate(batch):
speech_tokens.append(item["code"])
message_list_item = []
message_list_item += [
{"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
{"role": "assistant", "content": item["text"]},
]
messages.append(message_list_item)
durations.append(item["duration"])
ids.append(item["id"])
lang.append(item["language"])
dnsmos.append(item["dnsmos"])
input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
target_ids = target_ids.type(torch.LongTensor)
input_ids = input_ids.type(torch.LongTensor)
return {
"speech_tokens": speech_tokens,
"input_ids": input_ids,
"attention_mask": attention_mask,
"target_ids": target_ids,
"durations": durations,
"ids": ids,
"lang": lang,
"dnsmos": dnsmos,
}
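# The collated batch keeps "speech_tokens" as a plain Python list (one codec-id
# sequence per utterance, taken from item["code"]); it is not padded here and is
# handed to model.forward_with_speech_output unchanged via compute_loss.
# "durations", "ids", "lang" and "dnsmos" are carried along as per-utterance metadata.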
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute the loss for the given batch.
Args:
params:
It is returned by :func:`get_params`.
tokenizer:
The tokenizer used to encode the text.
model:
The model for training.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
Whether it is training.
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
device = next(model.parameters()).device
input_ids, attention_mask, target_ids, answer_cosyvoice_speech_token = batch["input_ids"], batch["attention_mask"], batch["target_ids"], batch["speech_tokens"]
with torch.set_grad_enabled(is_training):
(
text_loss,
acc,
codec_loss,
codec_acc,
codec_topk_acc,
) = model.forward_with_speech_output(
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
labels=target_ids.to(device),
speech_codec_ids=answer_cosyvoice_speech_token,
)
loss = text_loss + codec_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
feature_lens = batch["supervisions"]["num_frames"]
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["acc"] = (
acc * info["frames"]
) # WAR: to avoid normalization by the number of frames
info["codec_acc"] = codec_acc * info["frames"]
info["codec_topk_acc"] = codec_topk_acc * info["frames"]
info["codec_loss"] = codec_loss.detach().cpu().item()
info["text_loss"] = text_loss.detach().cpu().item()
return loss, info
def train_one_epoch(
params: AttributeDict,
tokenizer: AutoTokenizer,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
if params.enable_speech_input:
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
tokenizer=tokenizer,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
if params.enable_speech_input:
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx != 0:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"zero-checkpoint-{params.batch_idx_train}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
tag=f"zero-checkpoint-{params.batch_idx_train}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
# sampler_state_dict = train_dl.sampler.state_dict()
sampler_state_dict = train_dl.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
)
os.system(
f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
)
try:
with torch.amp.autocast("cuda", enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
tokenizer=tokenizer,
model=model,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
# deepspeed's backward() is different from torch's backward()
# in that it does not accept a loss tensor as input.
# It computes the loss internally.
model.backward(loss)
model.step()
except: # noqa
display_and_save_batch(batch, params=params)
raise
if batch_idx % params.log_interval == 0:
try:
cur_lr = scheduler.get_last_lr()[0]
except: # noqa
cur_lr = 0.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if rank == 0:
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info(params)
logging.info("About to create model")
model, tokenizer = get_model(params)
if torch.cuda.is_available():
device = torch.device("cuda", get_local_rank())
else:
device = torch.device("cpu")
logging.info(f"Device: {device}")
model.to(device)
assert params.deepspeed and world_size > 1, "this script assumes multi-GPU DeepSpeed training"
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
)
sampler_state_dict = None
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
ds = load_dataset(data_path, split="train")
train_test_split = dataset.train_test_split(test_size=1000, seed=42)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_dl = StatefulDataLoader(
train_dataset,
batch_size=2,
sampler=sampler,
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=lambda features: data_collator(
features, tokenizer
),
)
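# StatefulDataLoader (torchdata) exposes state_dict()/load_state_dict(), which is
# what train_one_epoch saves as sampler.pt and what makes mid-epoch resume possible.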
if sampler_state_dict is not None:
train_dl.load_state_dict(sampler_state_dict)
valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
valid_dl = DataLoader(
eval_dataset,
batch_size=2,
sampler=valid_sampler,
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=lambda features: data_collator(
features, tokenizer
),
)
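# Validation uses a plain DataLoader: its iteration position is never checkpointed,
# so the stateful variant is not needed here.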
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
logging.info(f"start training from epoch {params.start_epoch}")
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
tokenizer=tokenizer,
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"zero-epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
# sampler_state_dict = train_dl.sampler.state_dict()
sampler_state_dict = train_dl.state_dict()
torch.save(
sampler_state_dict,
f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!")
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = get_world_size()
rank = get_rank()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()