diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
index 97484486d..5ba3c1a7c 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py
@@ -59,9 +59,9 @@ class SPEECH_LLM(nn.Module):
     def __init__(
         self,
-        encoder: nn.Module,
-        llm: nn.Module,
-        encoder_projector: nn.Module,
+        encoder: nn.Module = None,
+        llm: nn.Module = None,
+        encoder_projector: nn.Module = None,
         codec_lm: nn.Module = None,
         codec_lm_padding_side: str = "left",
         teacher_llm: nn.Module = None,
@@ -330,20 +330,19 @@ class SPEECH_LLM(nn.Module):
         labels: torch.LongTensor = None,
         speech_codec_ids: torch.LongTensor = None,
     ):
-        encoder_outs = self.encoder(fbank)
-
-        speech_features = self.encoder_projector(encoder_outs)
-
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+        if fbank is not None:
+            encoder_outs = self.encoder(fbank)
+            speech_features = self.encoder_projector(encoder_outs)
+            (
+                inputs_embeds,
+                attention_mask,
+                labels,
+                _,
+            ) = self._merge_input_ids_with_speech_features(
+                speech_features, inputs_embeds, input_ids, attention_mask, labels
+            )
 
-        (
-            inputs_embeds,
-            attention_mask,
-            labels,
-            _,
-        ) = self._merge_input_ids_with_speech_features(
-            speech_features, inputs_embeds, input_ids, attention_mask, labels
-        )
         input_seq_len = attention_mask.sum(dim=1)  # shape, B
         (
             text_label_start_index_list,
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
index 87b8315f1..f1b25d3e6 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
@@ -69,8 +69,6 @@ from transformers import (
     Qwen2ForCausalLM,
 )
 
-# from icefall.env import get_env_info
-# from icefall import diagnostics
 from utils import (  # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -137,6 +135,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to enable speech codec output.",
     )
 
+    parser.add_argument(
+        "--enable-speech-input",
+        type=str2bool,
+        default=True,
+        help="Whether to enable speech fbank input.",
+    )
+
     parser.add_argument(
         "--speech-tokenizer-type",
         type=str,
@@ -145,11 +150,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
 
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
+def add_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--tensorboard",
         type=str2bool,
@@ -243,6 +244,12 @@ def get_parser():
         help="The name of the dataset.",
     )
 
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
     parser.add_argument(
         "--loss-type",
         type=str,
@@ -252,7 +259,7 @@ def get_parser():
 
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
-
+    add_training_arguments(parser)
     return parser
 
 
@@ -532,7 +539,6 @@ def compute_loss(
     feature = feature.to(device)
     feature = feature.transpose(1, 2)  # (N, C, T)
 
-    # WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
     messages, answer_cosyvoice_speech_token = extract_text_and_speech_token(
         batch, params.enable_speech_output
     )
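The constructor arguments become optional and `forward_with_speech_output` now skips the encoder branch when `fbank is None`, so the same model class serves both speech-input batches and text-only TTS batches. A minimal, self-contained sketch of that pattern; every class below is a toy stand-in, not the real `SPEECH_LLM`:

```python
# Toy illustration: the speech branch only runs when features are present.
import torch
import torch.nn as nn


class ToySpeechLLM(nn.Module):
    def __init__(self, encoder=None, encoder_projector=None, hidden=8, vocab=16):
        super().__init__()
        self.encoder = encoder  # may be None when speech input is disabled
        self.encoder_projector = encoder_projector
        self.embed = nn.Embedding(vocab, hidden)

    def forward(self, input_ids, fbank=None):
        inputs_embeds = self.embed(input_ids)
        if fbank is not None:  # mirror of the hunk above
            speech = self.encoder_projector(self.encoder(fbank))
            # stand-in for _merge_input_ids_with_speech_features:
            inputs_embeds = torch.cat([speech, inputs_embeds], dim=1)
        return inputs_embeds


model = ToySpeechLLM(encoder=nn.Linear(4, 8), encoder_projector=nn.Identity())
text_only = model(torch.tensor([[1, 2, 3]]))                     # TTS-style batch
with_speech = model(torch.tensor([[1, 2, 3]]), torch.randn(1, 2, 4))
print(text_only.shape, with_speech.shape)  # (1, 3, 8) and (1, 5, 8)
```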
""" - # model.encoder_projector.train() model.train() - model.encoder.eval() + if params.enable_speech_input: + model.encoder.eval() if not params.unfreeze_llm: model.llm.eval() tot_loss = MetricsTracker() @@ -706,7 +712,8 @@ def train_one_epoch( world_size=world_size, ) model.train() - model.encoder.eval() + if params.enable_speech_input: + model.encoder.eval() if not params.unfreeze_llm: model.llm.eval() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") @@ -796,36 +803,11 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - - if rank == 0: - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info(params) - logging.info("About to create model") - - replace_whisper_encoder_forward() - whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") - speech_encoder = whisper_model.encoder - speech_encoder_dim = whisper_model.dims.n_audio_state - for name, param in speech_encoder.named_parameters(): - param.requires_grad = False - +def get_model(params): + """Load and prepare the speech-to-speech model.""" tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) + special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} + tokenizer.add_special_tokens(special_tokens_dict) if params.use_flash_attn: attn_implementation = "flash_attention_2" @@ -842,11 +824,9 @@ def run(rank, world_size, args): attn_implementation=attn_implementation, torch_dtype=torch_dtype, ) - if not params.unfreeze_llm: for name, param in llm.named_parameters(): param.requires_grad = False - else: if params.use_lora: lora_config = LoraConfig( @@ -867,21 +847,29 @@ def run(rank, world_size, args): llm = get_peft_model(llm, lora_config) llm.print_trainable_parameters() - special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} - tokenizer.add_special_tokens(special_tokens_dict) - llm.config.pad_token_id = tokenizer.pad_token_id llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids( DEFAULT_SPEECH_TOKEN ) - encoder_projector = EncoderProjector( - speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate - ) - if not params.unfreeze_speech_projector: - for name, param in encoder_projector.named_parameters(): + if params.enable_speech_input: + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu") + speech_encoder = whisper_model.encoder + speech_encoder_dim = whisper_model.dims.n_audio_state + for name, param in speech_encoder.named_parameters(): param.requires_grad = False - encoder_projector.eval() + encoder_projector = EncoderProjector( + speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate + ) + if not params.unfreeze_speech_projector: + for name, param in encoder_projector.named_parameters(): + param.requires_grad = False + encoder_projector.eval() + else: + speech_encoder = None + encoder_projector = None if params.enable_speech_output: # Determine attn_implementation and torch_dtype based on use_flash_attn @@ -922,17 +910,6 @@ def 
@@ -922,17 +910,6 @@ def run(rank, world_size, args):
         codec_lm.config.mask_token_id = codec_vocab_size - 4
     else:
         codec_lm = None
-    if params.loss_type == "kl_div":
-        teacher_llm = AutoModelForCausalLM.from_pretrained(
-            params.llm_path_or_name,
-            attn_implementation=attn_implementation,
-            torch_dtype=torch_dtype,
-        )
-        for name, param in teacher_llm.named_parameters():
-            param.requires_grad = False
-        teacher_llm.eval()
-    else:
-        teacher_llm = None
 
     model = SPEECH_LLM(
         speech_encoder,
@@ -940,9 +917,7 @@ def run(rank, world_size, args):
         encoder_projector,
         codec_lm,
         codec_lm_padding_side="left" if params.use_flash_attn else "right",
-        teacher_llm=teacher_llm,
     )
-
     if params.pretrained_model_path or params.last_stage_model_path:
         if params.pretrained_model_path is None:
             checkpoint = torch.load(params.last_stage_model_path, map_location="cpu")
@@ -963,6 +938,32 @@ def run(rank, world_size, args):
         if param.requires_grad:
             logging.info(f"{name}: {param.shape}")
 
+    return model, tokenizer
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+
+    if rank == 0:
+        setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info(params)
+    logging.info("About to create model")
+
+    model, tokenizer = get_model(params)
+
     if torch.cuda.is_available():
         device = torch.device("cuda", get_local_rank())
     else:
@@ -1032,9 +1033,6 @@ def run(rank, world_size, args):
     elif params.dataset == "gigaspeech":
         train_cuts = data_module.train_cuts_gigaspeech()
         valid_cuts = data_module.valid_cuts_ultravox()
-    elif params.dataset == "emilia_en":
-        train_cuts = data_module.train_cuts_emilia_en()
-        valid_cuts = data_module.valid_cuts_emilia_en()
     else:
         raise ValueError(f"Unknown dataset: {params.dataset}")
@@ -1049,7 +1047,6 @@ def run(rank, world_size, args):
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-    # train_dl = data_module.valid_dataloaders(train_cuts)
     valid_dl = data_module.valid_dataloaders(valid_cuts)
 
     if args.tensorboard and rank == 0:
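The `emilia_en` lhotse branch can be dropped from train.py because the new train_tts.py (added below) reads Emilia directly through Hugging Face `datasets`. A sketch of that replacement pipeline, using the hard-coded path from train_tts.py, which will differ on other machines:

```python
# Mirror of the data loading that train_tts.py performs in run().
from datasets import load_dataset

ds = load_dataset("/lustre/fsw/general_sa/yuekaiz/s2s/emilia_en", split="train")
split = ds.train_test_split(test_size=1000, seed=42)  # held-out eval set
train_dataset, eval_dataset = split["train"], split["test"]
print(len(train_dataset), len(eval_dataset))
```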
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
new file mode 100755
index 000000000..8fd6609a4
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train_tts.py
@@ -0,0 +1,552 @@
+#!/usr/bin/env python3
+# Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
+#              2024  Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+
+torchrun --nproc_per_node $ngpu ./qwen_omni/train_tts.py \
+  --max-duration 50 \
+  --enable-musan False \
+  --exp-dir $exp_dir \
+  --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+  --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+  --manifest-dir data/fbank \
+  --deepspeed \
+  --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+  --use-flash-attn True \
+  --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+"""
+
+import argparse
+import copy
+import logging
+import os
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import deepspeed
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import transformers
+from datasets import load_dataset
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
+from lhotse.utils import fix_random_seed
+from model import IGNORE_TOKEN_ID, SPEECH_LLM
+from peft import LoraConfig, get_peft_model
+from torch import Tensor
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    Qwen2Config,
+    Qwen2ForCausalLM,
+)
+
+from train import (
+    add_model_arguments,
+    add_training_arguments,
+    compute_validation_loss,
+    display_and_save_batch,
+    get_model,
+    get_params,
+)
+from utils import (  # filter_uneven_sized_batch,
+    AttributeDict,
+    MetricsTracker,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_logger,
+    str2bool,
+)
+
+DEFAULT_SPEECH_TOKEN = "<speech>"
+
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    pass
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    # parser.add_argument(
+    #     "--loss-type",
+    #     type=str,
+    #     default="ce",
+    #     help="The type of loss to use.",
+    # )
+
+    parser = deepspeed.add_config_arguments(parser)
+    add_model_arguments(parser)
+    add_training_arguments(parser)
+    return parser
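For reference before `preprocess()` below: rendered with `tokenize=False`, its Jinja TEMPLATE produces the serialized dialog shown in the comment. A sketch assuming the Qwen2.5 tokenizer; any tokenizer that accepts a custom `chat_template` behaves the same:

```python
from transformers import AutoTokenizer

TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
    "{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}"
    "{% endfor %}"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
msg = [
    {"role": "user",
     "content": "Generate a speech from the following text:\n\nhello<speech>"},
    {"role": "assistant", "content": "hello"},
]
print(tokenizer.apply_chat_template(msg, tokenize=False, chat_template=TEMPLATE))
# <|im_start|>user
# Generate a speech from the following text:
#
# hello<speech><|im_end|>
# <|im_start|>assistant
# hello<|im_end|>
```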
+def preprocess(
+    messages,
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """Preprocesses the data for supervised fine-tuning."""
+    texts = []
+    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
+    for i, msg in enumerate(messages):
+        texts.append(
+            tokenizer.apply_chat_template(
+                msg,
+                tokenize=True,
+                chat_template=TEMPLATE,
+                add_generation_prompt=False,
+                padding="longest",  # FIX me change padding to longest
+                truncation=False,
+            )
+        )
+    if len(texts) != len(messages):
+        logging.warning(f"Remove too long text, {messages} ")
+    max_len_texts = max([len(text) for text in texts])
+    if tokenizer.padding_side == "right":
+        texts = [
+            text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
+            for text in texts
+        ]
+    else:
+        texts = [
+            [tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
+            for text in texts
+        ]
+    input_ids = torch.tensor(texts, dtype=torch.int)
+
+    target_ids = input_ids.clone()
+    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
+    # mask all tokens before token_id <speech> with IGNORE_TOKEN_ID
+    # first get the indices of the tokens
+    mask_prompt = True
+    if mask_prompt:
+        default_speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
+        mask_indices = torch.where(input_ids == default_speech_token_id)
+        for i in range(mask_indices[0].size(0)):
+            row = mask_indices[0][i]
+            col = mask_indices[1][i]
+            # + 6 to also skip the 5 tokens after <speech>:
+            # '<|im_end|>', '\n', '<|im_start|>', 'assistant', '\n'
+            # WAR: TODO FIXME check qwen3
+            # THIS IS THE ONLY DIFFERENCE FROM preprocess in train.py
+            target_ids[row, : col + 6] = IGNORE_TOKEN_ID
+            target_ids[row, col] = default_speech_token_id
+    # remove default_speech_token_id from target_ids and input_ids
+    batch_size = target_ids.size(0)
+
+    target_ids = target_ids[target_ids != default_speech_token_id].view(batch_size, -1)
+    input_ids = input_ids[input_ids != default_speech_token_id].view(batch_size, -1)
+
+    attention_mask = input_ids.ne(tokenizer.pad_token_id)
+    return input_ids, attention_mask, target_ids
+
+
+def data_collator(batch, tokenizer, cut_off_len=2048):
+    speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
+    for i, item in enumerate(batch):
+        speech_tokens.append(item["code"])
+        message_list_item = []
+        message_list_item += [
+            {
+                "role": "user",
+                "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": item["text"]},
+        ]
+        messages.append(message_list_item)
+        durations.append(item["duration"])
+        ids.append(item["id"])
+        lang.append(item["language"])
+        dnsmos.append(item["dnsmos"])
+
+    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
+    target_ids = target_ids.type(torch.LongTensor)
+    input_ids = input_ids.type(torch.LongTensor)
+
+    return {
+        "speech_tokens": speech_tokens,
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "target_ids": target_ids,
+        "durations": durations,
+        "ids": ids,
+        "lang": lang,
+        "dnsmos": dnsmos,
+    }
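A quick smoke test for `data_collator`, assuming Emilia-style rows with the six fields it reads ("code", "text", "duration", "id", "language", "dnsmos"); the codec token ids and DNSMOS scores are invented for illustration:

```python
from transformers import AutoTokenizer

from train_tts import data_collator  # assumes this recipe dir is on PYTHONPATH

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer.add_special_tokens({"additional_special_tokens": ["<speech>"]})

rows = [
    {"code": [12, 57, 33], "text": "hello world", "duration": 1.2,
     "id": "EN_0001", "language": "en", "dnsmos": 3.4},
    {"code": [7, 7, 91, 4], "text": "good morning", "duration": 1.5,
     "id": "EN_0002", "language": "en", "dnsmos": 3.1},
]
batch = data_collator(rows, tokenizer)
print(batch["input_ids"].shape, batch["attention_mask"].shape)
print(batch["speech_tokens"])  # passed through untouched for the codec LM
```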
+def compute_loss(
+    params: AttributeDict,
+    tokenizer: AutoTokenizer,
+    model: nn.Module,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute the loss for the given batch.
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      tokenizer:
+        The tokenizer used to encode the text.
+      model:
+        The model for training.
+      batch:
+        A batch of data produced by :func:`data_collator`.
+      is_training:
+        Whether it is training.
+    Returns:
+      Return a tuple of two elements. The first element is the loss tensor,
+      the second is a `MetricsTracker` with statistics for logging.
+    """
+    device = next(model.parameters()).device
+    input_ids, attention_mask, target_ids, answer_cosyvoice_speech_token = (
+        batch["input_ids"],
+        batch["attention_mask"],
+        batch["target_ids"],
+        batch["speech_tokens"],
+    )
+
+    with torch.set_grad_enabled(is_training):
+        (
+            text_loss,
+            acc,
+            codec_loss,
+            codec_acc,
+            codec_topk_acc,
+        ) = model.forward_with_speech_output(
+            input_ids=input_ids.to(device),
+            attention_mask=attention_mask.to(device),
+            labels=target_ids.to(device),
+            speech_codec_ids=answer_cosyvoice_speech_token,
+        )
+        loss = text_loss + codec_loss
+        assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # This batch comes from data_collator, not lhotse, so there is no
+        # "supervisions" field; use the unpadded token count as the
+        # normalization denominator instead of acoustic frames.
+        feature_lens = attention_mask.sum(dim=1)
+        info["frames"] = feature_lens.sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["acc"] = (
+        acc * info["frames"]
+    )  # WAR: to avoid normalization by the number of frames
+    info["codec_acc"] = codec_acc * info["frames"]
+    info["codec_topk_acc"] = codec_topk_acc * info["frames"]
+    info["codec_loss"] = codec_loss.detach().cpu().item()
+    info["text_loss"] = text_loss.detach().cpu().item()
+    return loss, info
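On the `acc * info["frames"]` trick above: `MetricsTracker` sums each field across batches, and the logging code divides by the summed "frames", so per-batch averages must be pre-weighted to survive that division. A minimal sketch of the convention with plain dicts:

```python
# Two batches with different sizes; the running tracker keeps sums only.
batches = [{"frames": 100, "acc": 0.50}, {"frames": 300, "acc": 0.90}]
tot = {"frames": 0, "acc": 0.0}
for b in batches:
    tot["frames"] += b["frames"]
    tot["acc"] += b["acc"] * b["frames"]  # same trick as info["acc"] above
print(tot["acc"] / tot["frames"])         # 0.8, the frame-weighted mean accuracy
```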
+def train_one_epoch(
+    params: AttributeDict,
+    tokenizer: AutoTokenizer,
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+    if params.enable_speech_input:
+        model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["input_ids"])
+
+        if batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            if params.enable_speech_input:
+                model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+            if batch_idx != 0:
+                model.save_checkpoint(
+                    save_dir=params.exp_dir,
+                    tag=f"zero-checkpoint-{params.batch_idx_train}",
+                    client_state={},
+                    exclude_frozen_parameters=True,
+                )
+
+                if rank == 0:
+                    convert_zero_checkpoint_to_fp32_state_dict(
+                        params.exp_dir,
+                        f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
+                        tag=f"zero-checkpoint-{params.batch_idx_train}",
+                        exclude_frozen_parameters=True,
+                    )
+                    # save sampler state dict into checkpoint
+                    # sampler_state_dict = train_dl.sampler.state_dict()
+                    sampler_state_dict = train_dl.state_dict()
+                    torch.save(
+                        sampler_state_dict,
+                        f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
+                    )
+                    os.system(
+                        f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
+                    )
+
+        try:
+            with torch.amp.autocast("cuda", enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    tokenizer=tokenizer,
+                    model=model,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+
+            # deepspeed wraps backward() and step(): backward() applies loss
+            # scaling and step() runs the optimizer and LR scheduler,
+            # including gradient accumulation, internally.
+            model.backward(loss)
+            model.step()
+
+        except:  # noqa
+            display_and_save_batch(batch, params=params)
+            raise
+
+        if batch_idx % params.log_interval == 0:
+            try:
+                cur_lr = scheduler.get_last_lr()[0]
+            except:  # noqa
+                cur_lr = 0.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
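The checkpoint block above stores `train_dl.state_dict()` rather than a lhotse sampler state: torchdata's `StatefulDataLoader` can snapshot its position mid-stream and resume from it. A self-contained sketch of that save/resume cycle:

```python
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader


class Squares(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, i):
        return i * i


dl = StatefulDataLoader(Squares(), batch_size=2, num_workers=0)
it = iter(dl)
next(it), next(it)              # consume two batches: [0, 1] and [4, 9]
state = dl.state_dict()         # snapshot after batch 2 (what sampler.pt stores)

resumed = StatefulDataLoader(Squares(), batch_size=2, num_workers=0)
resumed.load_state_dict(state)  # picks up where the snapshot left off
print(next(iter(resumed)))      # tensor([16, 25])
```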
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+
+    if rank == 0:
+        setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info(params)
+    logging.info("About to create model")
+
+    model, tokenizer = get_model(params)
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda", get_local_rank())
+    else:
+        device = torch.device("cpu")
+    logging.info(f"Device: {device}")
+    model.to(device)
+
+    assert params.deepspeed and world_size > 1
+    logging.info("Using DeepSpeed")
+    model, optimizer, _, scheduler = deepspeed.initialize(
+        args=params, model=model, model_parameters=model.parameters()
+    )
+
+    sampler_state_dict = None
+    if params.sampler_state_dict_path:
+        sampler_state_dict = torch.load(params.sampler_state_dict_path)
+
+    data_path = "/lustre/fsw/general_sa/yuekaiz/s2s/emilia_en"
+    ds = load_dataset(data_path, split="train")
+    train_test_split = ds.train_test_split(test_size=1000, seed=42)
+    train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
+
+    sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
+    train_dl = StatefulDataLoader(
+        train_dataset,
+        batch_size=2,
+        sampler=sampler,
+        shuffle=False,
+        num_workers=1,
+        prefetch_factor=1,
+        collate_fn=lambda features: data_collator(features, tokenizer),
+    )
+    if sampler_state_dict is not None:
+        train_dl.load_state_dict(sampler_state_dict)
+    valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
+    valid_dl = DataLoader(
+        eval_dataset,
+        batch_size=2,
+        sampler=valid_sampler,
+        shuffle=False,
+        num_workers=1,
+        prefetch_factor=1,
+        collate_fn=lambda features: data_collator(features, tokenizer),
+    )
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    logging.info(f"start training from epoch {params.start_epoch}")
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            tokenizer=tokenizer,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        model.save_checkpoint(
+            save_dir=params.exp_dir,
+            tag=f"zero-epoch-{params.cur_epoch}",
+            client_state={},
+            exclude_frozen_parameters=True,
+        )
+        if rank == 0:
+            convert_zero_checkpoint_to_fp32_state_dict(
+                params.exp_dir,
+                f"{params.exp_dir}/epoch-{params.cur_epoch}",
+                tag=f"zero-epoch-{params.cur_epoch}",
+                exclude_frozen_parameters=True,
+            )
+            # save sampler state dict into checkpoint
+            # sampler_state_dict = train_dl.sampler.state_dict()
+            sampler_state_dict = train_dl.state_dict()
+            torch.save(
+                sampler_state_dict,
+                f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
+            )
+
+        os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
+
+    logging.info("Done!")
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = get_world_size()
+    rank = get_rank()
+
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    run(rank=rank, world_size=world_size, args=args)
+
+
+if __name__ == "__main__":
+    main()
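How a follow-up run might consume these converted checkpoints, mirroring the `pretrained_model_path` branch in `get_model()`: load the consolidated file and apply it non-strictly, since `exclude_frozen_parameters=True` omits frozen weights from the file. The file name under `epoch-N` depends on the DeepSpeed version, and the paths and flags here are illustrative only:

```python
import torch

from train import get_model, get_params, get_parser

args = get_parser().parse_args([])  # defaults; real runs pass the usual flags
params = get_params()
params.update(vars(args))
model, tokenizer = get_model(params)

checkpoint = torch.load("exp_dir/epoch-10/pytorch_model.bin", map_location="cpu")
missing, unexpected = model.load_state_dict(checkpoint, strict=False)
print(f"{len(missing)} frozen params absent from checkpoint, {len(unexpected)} extra")
```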