From f81363d3243731653cc94dfc303f04d1b09e31ac Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 15 May 2025 14:16:51 +0000
Subject: [PATCH] add speech continuation pretraining

---
 egs/speech_llm/SPEECH2SPEECH/prepare.sh        |  75 ++++-
 .../SPEECH2SPEECH/qwen_omni/data_module.py     | 306 ++++++++++++++----
 .../SPEECH2SPEECH/qwen_omni/requirements.txt   |   1 +
 .../SPEECH2SPEECH/qwen_omni/train.py           | 104 +++---
 4 files changed, 391 insertions(+), 95 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index c974ee88f..fd8070691 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -122,7 +122,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "stage 1: Compute fbank feature from huggingface"
+  log "stage 6: Compute fbank feature from huggingface"
   # CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
   #   --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
   #   --out-dir data/fbank_voice_assistant \
@@ -161,10 +161,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     --subset test --split test \
     --audio-key audio --text-key text \
     --prefix gigaspeech
-fi
 
-if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-  log "stage 9: Compute fbank feature from huggingface"
   CUDA_VISIBLE_DEVICES=0 python3 local/compute_whisper_fbank.py \
     --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb True \
     --out-dir data/fbank_gigaspeech \
@@ -195,7 +192,7 @@ fi
 
 
 if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-  log "stage 11: Decoding EN, only support batch_size=1 for now."
+  log "stage 11: Decoding EN, the validation set only supports batch_size=1 for now."
   exp_dir=./qwen_omni/exp_speech2speech_en_continue
   # cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
   python3 ./qwen_omni/decode.py \
@@ -256,3 +253,71 @@ if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
       --output-dir test_result
   done
 fi
+
+
+if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
+  log "stage 15: Training Speech2Speech Model, adapter only"
+  exp_dir=./qwen_omni/exp_speech2text
+  ngpu=2
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    --max-duration 600 \
+    --enable-musan False \
+    --audio-key audio --text-key continuation \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/large-v2.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --on-the-fly-feats True \
+    --deepspeed \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --dataset-format speech_continuation \
+    --start-epoch 2 --pretrained-model-path $exp_dir/epoch-1/pytorch_model.bin \
+    --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False
+fi
+
+if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
+  log "stage 16: Training Speech2Speech Model, adapter only"
+  exp_dir=./qwen_omni/exp_speech2text
+  ngpu=4
+
+  latest_checkpoint_step=-1
+  # Check if exp_dir exists and is a directory
+  if [ -d "$exp_dir" ]; then
+    # List directories matching checkpoint-* and find the one with the largest step number
+    for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+      checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+      # Extract step number using parameter expansion
+      current_step=${checkpoint_name#checkpoint-}
+      # Ensure current_step is a number
+      if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then + latest_checkpoint_step=$current_step + fi + done + fi + + train_cmd_args="--max-duration 1200 \ + --enable-musan False \ + --audio-key audio --text-key continuation \ + --exp-dir $exp_dir \ + --speech-encoder-path-or-name models/large-v2.pt \ + --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \ + --on-the-fly-feats True \ + --deepspeed \ + --huggingface-dataset-path-or-name /lustre/fsw/general_sa/yuekaiz/s2s \ + --deepspeed_config ./qwen_omni/ds_config_zero1.json \ + --use-flash-attn True \ + --dataset-format speech_continuation \ + --use-lora False --unfreeze-llm False --unfreeze-speech-projector True --enable-speech-output False" + + if [ "$latest_checkpoint_step" -ge 0 ]; then + log "Continuing training from checkpoint-$latest_checkpoint_step" + step=$latest_checkpoint_step + train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt" + else + log "Starting training from scratch as no checkpoint was found in $exp_dir" + # No pretrained model or sampler state dict needed for the first run + fi + + torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ + $train_cmd_args +fi diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py index bc75bccd6..1f35f9b84 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py @@ -24,7 +24,7 @@ from pathlib import Path from typing import Any, Dict, Optional import torch -from datasets import load_dataset +from datasets import interleave_datasets, load_dataset from lhotse import ( CutSet, WhisperFbank, @@ -36,6 +36,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, + PerturbSpeed, PrecomputedFeatures, SimpleCutSampler, SpecAugment, @@ -47,7 +48,7 @@ from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.utils import fix_random_seed from speech_dataset import K2SpeechRecognitionDataset from torch.utils.data import DataLoader -from utils import str2bool +from utils import get_rank, str2bool class _SeedWorkers: @@ -123,6 +124,14 @@ class AsrDataModule: "extraction. Will drop existing precomputed feature manifests " "if available.", ) + group.add_argument( + "--on-the-fly-speed-perturb", + type=str2bool, + default=True, + help="When enabled, use on-the-fly speed perturbation. " + "Will drop existing precomputed feature manifests " + "if available.", + ) group.add_argument( "--shuffle", type=str2bool, @@ -188,27 +197,27 @@ class AsrDataModule: group.add_argument( "--huggingface-dataset-path-or-name", type=str, - default="/workspace/Belle_1.4M-SLAM-Omni", + default=None, help="The path or name of the Huggingface dataset", ) group.add_argument( "--audio-key", type=str, - default="question_audio", + default="audio", help="The key in the Huggingface dataset containing the audio data", ) group.add_argument( "--text-key", type=str, - default="answer", + default="text", help="The key in the Huggingface dataset containing the text data", ) - group.add_argument( - "--resample-to-16kHz", - type=str2bool, - default=True, - help="Resample audio to 16kHz. Default: False.", - ) + # group.add_argument( + # "--resample-to-16kHz", + # type=str2bool, + # default=True, + # help="Resample audio to 16kHz. 
Default: False.", + # ) def train_dataloaders( self, @@ -232,6 +241,8 @@ class AsrDataModule: ) else: logging.info("Disable MUSAN") + if self.args.on_the_fly_speed_perturb and self.args.on_the_fly_feats: + transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms input_transforms = [] if self.args.enable_spec_aug: @@ -260,9 +271,11 @@ class AsrDataModule: logging.info("Disable SpecAugment") logging.info("About to create train dataset") + rank = get_rank() + train = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures( - WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda")) + WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}")) ) if self.args.on_the_fly_feats else eval(self.args.input_strategy)(), @@ -271,26 +284,6 @@ class AsrDataModule: return_cuts=self.args.return_cuts, ) - # if self.args.on_the_fly_feats: - # # NOTE: the PerturbSpeed transform should be added only if we - # # remove it from data prep stage. - # # Add on-the-fly speed perturbation; since originally it would - # # have increased epoch size by 3, we will apply prob 2/3 and use - # # 3x more epochs. - # # Speed perturbation probably should come first before - # # concatenation, but in principle the transforms order doesn't have - # # to be strict (e.g. could be randomized) - # # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # # Drop feats to be on the safe side. - # train = K2SpeechRecognitionDataset( - # cut_transforms=transforms, - # input_strategy=OnTheFlyFeatures( - # WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda")) - # ), - # input_transforms=input_transforms, - # return_cuts=self.args.return_cuts, - # ) - if self.args.bucketing_sampler: logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( @@ -298,8 +291,7 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 1000, drop_last=self.args.drop_last, ) else: @@ -339,10 +331,10 @@ class AsrDataModule: CutSet for validation. 
""" logging.info("About to create dev dataset") - + rank = get_rank() validate = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures( - WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda")) + WhisperFbank(WhisperFbankConfig(num_filters=80, device=f"cuda:{rank}")) ) if self.args.on_the_fly_feats else eval(self.args.input_strategy)(), @@ -470,25 +462,231 @@ class AsrDataModule: ) return {"test": VoiceAssistant_cuts} - # def train_cuts_en_vocalnet(self) -> CutSet: + @lru_cache() + def train_cuts_ultravox(self) -> CutSet: + logging.info("About to get train cuts") + if self.args.huggingface_dataset_path_or_name is not None: + librispeech_path = ( + self.args.huggingface_dataset_path_or_name + "/librispeech_asr" + ) + people_speech_path = ( + self.args.huggingface_dataset_path_or_name + "/peoples_speech" + ) + gigaspeech_path = self.args.huggingface_dataset_path_or_name + "/gigaspeech" + else: + librispeech_path = "fixie-ai/librispeech_asr" + people_speech_path = "fixie-ai/peoples_speech" + gigaspeech_path = "fixie-ai/gigaspeech" + # 148_688 + librispeech_other = load_dataset( + librispeech_path, "other", split="train.500", streaming=True + ) + # 104_014 + librispeech_clean_360 = load_dataset( + librispeech_path, "clean", split="train.360", streaming=True + ) + # 28_539 + librispeech_clean_100 = load_dataset( + librispeech_path, "clean", split="train.100", streaming=True + ) + + # 1_501_271 + people_speech_clean = load_dataset( + people_speech_path, "clean", split="train", streaming=True + ) + # 548_000 + people_speech_dirty_sa = load_dataset( + people_speech_path, "dirty_sa", split="train", streaming=True + ) + + # 8_266_422 + + gigaspeech = load_dataset( + gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True + ) + + librispeech_clean_100_cuts = CutSet.from_huggingface_dataset( + librispeech_clean_100, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + librispeech_other_cuts = CutSet.from_huggingface_dataset( + librispeech_other, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + librispeech_clean_360_cuts = CutSet.from_huggingface_dataset( + librispeech_clean_360, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + gigaspeech_cuts = CutSet.from_huggingface_dataset( + gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key + ) + + people_speech_clean_cuts = CutSet.from_huggingface_dataset( + people_speech_clean, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset( + people_speech_dirty_sa, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + return CutSet.mux( + librispeech_clean_100_cuts, + librispeech_clean_360_cuts, + librispeech_other_cuts, + gigaspeech_cuts, + people_speech_clean_cuts, + people_speech_dirty_sa_cuts, + weights=[ + 28539, + 104014, + 148688, + 8266422, + 1501271, + 548000, + ], + ) + + # @lru_cache() + # def train_cuts_ultravox(self) -> CutSet: # logging.info("About to get train cuts") - # VoiceAssistant_cuts = load_manifest_lazy( - # self.args.manifest_dir / "cuts_debug.jsonl.gz" - # ) - # return VoiceAssistant_cuts + # keep_columns = ["audio", "text", "continuation", "id"] + # librispeech_path="fixie-ai/librispeech_asr" + # # 148_688 + # librispeech_other = load_dataset(librispeech_path, 'other', split='train.500', streaming=True) + # # 104_014 + # librispeech_clean_360 = load_dataset(librispeech_path, 'clean', split='train.360', streaming=True) + # # 28_539 + 
+    #     librispeech_clean_100 = load_dataset(librispeech_path, 'clean', split='train.100', streaming=True)
 
-    # @lru_cache()
-    # def valid_cuts_en_vocalnet(self) -> CutSet:
-    #     logging.info("About to get valid cuts")
-    #     VoiceAssistant_cuts = load_manifest_lazy(
-    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
-    #     )
-    #     return VoiceAssistant_cuts
+    #     cols_to_remove = librispeech_clean_100.column_names
+    #     cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+    #     librispeech_clean_100 = librispeech_clean_100.remove_columns(cols_to_remove)
+    #     librispeech_clean_360 = librispeech_clean_360.remove_columns(cols_to_remove)
+    #     librispeech_other = librispeech_other.remove_columns(cols_to_remove)
+    #     people_speech_path="fixie-ai/peoples_speech"
+    #     # 1_501_271
+    #     people_speech_clean = load_dataset(people_speech_path, 'clean', split='train', streaming=True)
+    #     # 548_000
+    #     people_speech_dirty_sa = load_dataset(people_speech_path, 'dirty_sa', split='train', streaming=True)
+    #     cols_to_remove = people_speech_clean.column_names
+    #     cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+    #     people_speech_clean = people_speech_clean.remove_columns(cols_to_remove)
+    #     people_speech_dirty_sa = people_speech_dirty_sa.remove_columns(cols_to_remove)
 
-    # @lru_cache()
-    # def test_cuts_en_vocalnet(self) -> CutSet:
-    #     logging.info("About to get test cuts")
-    #     VoiceAssistant_cuts = load_manifest_lazy(
-    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
+    #     # 8_266_422
+    #     gigaspeech_path="fixie-ai/gigaspeech"
+    #     gigaspeech = load_dataset(gigaspeech_path, 'xl-empty-audio-removed', split='train', streaming=True)
+    #     # first rename segment_id to id
+    #     gigaspeech = gigaspeech.rename_column("segment_id", "id")
+    #     cols_to_remove = gigaspeech.column_names
+    #     cols_to_remove = [col for col in cols_to_remove if col not in keep_columns]
+    #     gigaspeech = gigaspeech.remove_columns(cols_to_remove)
+
+    #     total_item = 104014 + 28539 + 8266422 + 1501271 + 548000 + 148688
+    #     final_datasets = interleave_datasets([
+    #         librispeech_clean_100,
+    #         librispeech_clean_360,
+    #         gigaspeech,
+    #         people_speech_clean,
+    #         people_speech_dirty_sa,
+    #         librispeech_other,
+    #     ], probabilities=[
+    #         28539 / total_item,
+    #         104014 / total_item,
+    #         8266422 / total_item,
+    #         1501271 / total_item,
+    #         548000 / total_item,
+    #         148688 / total_item,
+    #     ])
+
+    #     train_cuts = CutSet.from_huggingface_dataset(
+    #         final_datasets, audio_key=self.args.audio_key, text_key=self.args.text_key
     #     )
-    #     return VoiceAssistant_cuts
+
+    #     return train_cuts
+
+    @lru_cache()
+    def valid_cuts_ultravox(self) -> CutSet:
+        logging.info("About to get valid cuts")
+        librispeech_path = "fixie-ai/librispeech_asr"
+        librispeech_clean_valid = load_dataset(
+            librispeech_path, "clean", split="validation", streaming=True
+        )
+        librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_valid,
+            audio_key=self.args.audio_key,
+            text_key=self.args.text_key,
+        )
+        return librispeech_clean_valid_cuts
+
+    @lru_cache()
+    def train_cuts_librispeech(self) -> CutSet:
+        logging.info("About to get train cuts")
+
+        # librispeech_path="fixie-ai/librispeech_asr"
+        librispeech_path = "/workspace/slam/librispeech_asr"
+        # 148_688
+        librispeech_other = load_dataset(
+            librispeech_path, "other", split="train.500", streaming=True
+        )
+        # 104_014
+        librispeech_clean_360 = load_dataset(
+            librispeech_path, "clean", split="train.360", streaming=True
+        )
+        # 28_539
+        librispeech_clean_100 = load_dataset(
+            librispeech_path, "clean", split="train.100", streaming=True
split="train.100", streaming=True + ) + + librispeech_clean_100_cuts = CutSet.from_huggingface_dataset( + librispeech_clean_100, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + librispeech_other_cuts = CutSet.from_huggingface_dataset( + librispeech_other, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + librispeech_clean_360_cuts = CutSet.from_huggingface_dataset( + librispeech_clean_360, + audio_key=self.args.audio_key, + text_key=self.args.text_key, + ) + + return CutSet.mux( + librispeech_clean_100_cuts, + librispeech_clean_360_cuts, + librispeech_other_cuts, + weights=[ + 28539, + 104014, + 148688, + ], + ) + + @lru_cache() + def train_cuts_gigaspeech(self) -> CutSet: + logging.info("About to get train cuts") + gigaspeech_path = "fixie-ai/gigaspeech" + gigaspeech = load_dataset( + gigaspeech_path, "xl-empty-audio-removed", split="train", streaming=True + ) + + gigaspeech_cuts = CutSet.from_huggingface_dataset( + gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key + ) + + return gigaspeech_cuts diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt index 2db53f3ff..573e8232d 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/requirements.txt @@ -10,3 +10,4 @@ transformers>=4.37.0 flash-attn peft torchmetrics +triton==3.3.0 # may be violate with openai-whisper diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py index d23d578c6..1ed0204db 100755 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py @@ -68,24 +68,26 @@ from transformers import ( Qwen2Config, Qwen2ForCausalLM, ) -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -# from icefall import diagnostics -from utils import get_rank, get_world_size # from icefall.env import get_env_info +# from icefall import diagnostics from utils import ( # filter_uneven_sized_batch, AttributeDict, MetricsTracker, + get_rank, + get_world_size, setup_logger, str2bool, ) +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward DEFAULT_SPEECH_TOKEN = "" try: - torch.multiprocessing.set_start_method('spawn') + torch.multiprocessing.set_start_method("spawn") except RuntimeError: pass + def set_batch_count(model: nn.Module, batch_count: float) -> None: for module in model.modules(): if hasattr(module, "batch_count"): @@ -272,7 +274,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 5000, + "valid_interval": 3000, # "env_info": get_env_info(), } ) @@ -332,6 +334,21 @@ def process_batch_vocalnet(batch: dict): return messages, answer_cosyvoice_speech_token +def process_batch_speech_continuation(batch: dict): + messages = [] + for i in range(len(batch["supervisions"]["text"])): + message = [ + { + "role": "user", + "content": f"Continue the following text using less than 50 words:\n\n{DEFAULT_SPEECH_TOKEN}", + }, + {"role": "assistant", "content": batch["supervisions"]["text"][i]}, + ] + # transcript = batch["supervisions"]["cut"][i].custom["text"] + messages.append(message) + return messages + + def compute_loss( params: AttributeDict, tokenizer: AutoTokenizer, @@ -429,13 +446,13 @@ def compute_loss( feature = feature.to(device) feature = feature.transpose(1, 2) # (N, C, T) - batch_idx_train = params.batch_idx_train - # 
     # WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
     if params.dataset_format == "slam_omni":
         messages, answer_cosyvoice_speech_token = process_batch_slam_omni(batch)
     elif params.dataset_format == "vocalnet":
         messages, answer_cosyvoice_speech_token = process_batch_vocalnet(batch)
+    elif params.dataset_format == "speech_continuation":
+        messages = process_batch_speech_continuation(batch)
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")
 
@@ -566,8 +583,11 @@ def train_one_epoch(
             The rank of the node in DDP training. If no DDP is used, it should
             be set to 0.
     """
-    model.encoder_projector.train()
-
+    # model.encoder_projector.train()
+    model.train()
+    model.encoder.eval()
+    if not params.unfreeze_llm:
+        model.llm.eval()
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
@@ -583,6 +603,9 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
+            model.encoder.eval()
+            if not params.unfreeze_llm:
+                model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -594,7 +617,7 @@ def train_one_epoch(
             if batch_idx != 0:
                 model.save_checkpoint(
                     save_dir=params.exp_dir,
-                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                    tag=f"zero-checkpoint-{params.batch_idx_train}",
                     client_state={},
                     exclude_frozen_parameters=True,
                 )
@@ -602,18 +625,18 @@ def train_one_epoch(
                 if rank == 0:
                     convert_zero_checkpoint_to_fp32_state_dict(
                         params.exp_dir,
-                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
-                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                        f"{params.exp_dir}/checkpoint-{params.batch_idx_train}",
+                        tag=f"zero-checkpoint-{params.batch_idx_train}",
                         exclude_frozen_parameters=True,
                     )
                     # save sampler state dict into checkpoint
                     sampler_state_dict = train_dl.sampler.state_dict()
                     torch.save(
                         sampler_state_dict,
-                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
+                        f"{params.exp_dir}/checkpoint-{params.batch_idx_train}/sampler.pt",
                     )
                     os.system(
-                        f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+                        f"rm -rf {params.exp_dir}/zero-checkpoint-{params.batch_idx_train}"
                     )
         try:
             with torch.amp.autocast("cuda", enabled=params.use_fp16):
@@ -687,9 +710,9 @@ def run(rank, world_size, args):
 
     fix_random_seed(params.seed)
 
-    setup_logger(f"{params.exp_dir}/log/log-train")
+    if rank == 0:
+        setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info(params)
-
     logging.info("About to create model")
 
     replace_whisper_encoder_forward()
@@ -698,7 +721,6 @@ def run(rank, world_size, args):
     speech_encoder_dim = whisper_model.dims.n_audio_state
     for name, param in speech_encoder.named_parameters():
         param.requires_grad = False
-    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
 
@@ -721,7 +743,7 @@ def run(rank, world_size, args):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-        llm.eval()
+
     else:
         if params.use_lora:
             lora_config = LoraConfig(
@@ -809,6 +831,9 @@ def run(rank, world_size, args):
     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
         missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+        # set params.batch_idx_train according to the checkpoint name
+        if "checkpoint-" in params.pretrained_model_path:
+            params.batch_idx_train = int(params.pretrained_model_path.split("-")[-1])
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -842,21 +867,22 @@ def run(rank, world_size, args):
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 30.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
-            return False
-        codec_len = (
-            len(c.custom["answer_cosyvoice_speech_token"])
-            if "answer_cosyvoice_speech_token" in c.custom
-            else len(c.custom["speech_token"])
-        )
-        if codec_len > 2200:
+        if c.duration < 1.0 or c.duration > 29.5:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
+        if "speech_token" in c.custom or "answer_cosyvoice_speech_token" in c.custom:
+            codec_len = (
+                len(c.custom["answer_cosyvoice_speech_token"])
+                if "answer_cosyvoice_speech_token" in c.custom
+                else len(c.custom["speech_token"])
+            )
+            if codec_len > 2200:
+                logging.warning(
+                    f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
+                )
+                return False
         return True
 
     if params.dataset_format == "slam_omni":
@@ -865,6 +891,11 @@ def run(rank, world_size, args):
     elif params.dataset_format == "vocalnet":
         train_cuts = data_module.train_cuts_en_vocalnet()
         valid_cuts = data_module.valid_cuts_en_vocalnet()
+    elif params.dataset_format == "speech_continuation":
+        # train_cuts = data_module.train_cuts_ultravox()
+        # train_cuts = data_module.train_cuts_gigaspeech()
+        train_cuts = data_module.train_cuts_librispeech()
+        valid_cuts = data_module.valid_cuts_ultravox()
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")
 
@@ -879,7 +910,7 @@ def run(rank, world_size, args):
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-
+    # train_dl = data_module.valid_dataloaders(train_cuts)
     valid_dl = data_module.valid_dataloaders(valid_cuts)
 
     if args.tensorboard and rank == 0:
@@ -913,25 +944,25 @@ def run(rank, world_size, args):
 
         model.save_checkpoint(
             save_dir=params.exp_dir,
-            tag=f"epoch-{params.cur_epoch}",
+            tag=f"zero-epoch-{params.cur_epoch}",
            client_state={},
             exclude_frozen_parameters=True,
         )
        if rank == 0:
            convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
-                f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                tag=f"epoch-{params.cur_epoch}",
+                f"{params.exp_dir}/epoch-{params.cur_epoch}",
+                tag=f"zero-epoch-{params.cur_epoch}",
                exclude_frozen_parameters=True,
            )
            # save sampler state dict into checkpoint
            sampler_state_dict = train_dl.sampler.state_dict()
            torch.save(
                sampler_state_dict,
-                f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
+                f"{params.exp_dir}/epoch-{params.cur_epoch}/sampler.pt",
            )
-            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
+            os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
 
     logging.info("Done!")
 
@@ -971,6 +1002,7 @@ def main():
 
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)