From b20a0d0e35f4d0c83146c29c14cdc473311a2781 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 8 May 2025 19:21:41 -0700
Subject: [PATCH] add on-the-fly feature extraction option

Add an --on-the-fly-feats option that computes Whisper fbank features
with lhotse's OnTheFlyFeatures during training instead of reading
precomputed features from disk, and enable it in the stage-10 training
command. Also honor --bucketing-sampler for the validation set, and set
the 'spawn' multiprocessing start method so the CUDA-based extractor
can run inside dataloader workers. Training defaults are bumped to
match (--max-duration 150, 4 dataloader workers, 4 GPUs).
---
 egs/speech_llm/SPEECH2SPEECH/prepare.sh            |  7 +++----
 .../SPEECH2SPEECH/qwen_omni/data_module.py         | 52 ++++++++------------
 .../SPEECH2SPEECH/qwen_omni/train.py               |  8 ++++++++
 3 files changed, 33 insertions(+), 34 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index 58465c448..fcdfdd69f 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -174,13 +174,12 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
     --prefix gigaspeech
 fi
 
-
-ngpu=2
+ngpu=4
 exp_dir=./qwen_omni/exp_speech2speech_en
 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   log "stage 10: Training Speech2Speech Model"
   torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
-    --max-duration 50 \
+    --max-duration 150 \
     --enable-musan False \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/large-v2.pt \
@@ -189,6 +189,6 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
-    --use-flash-attn True \
+    --use-flash-attn True --on-the-fly-feats True \
     --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 fi
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
index 7bd0a174a..b0b039416 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
@@ -149,7 +149,7 @@ class AsrDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=2,
+            default=4,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
@@ -262,31 +262,16 @@ class AsrDataModule:
 
         logging.info("About to create train dataset")
         train = K2SpeechRecognitionDataset(
-            input_strategy=eval(self.args.input_strategy)(),
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+            )
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
             cut_transforms=transforms,
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
-        if self.args.on_the_fly_feats:
-            # NOTE: the PerturbSpeed transform should be added only if we
-            # remove it from data prep stage.
-            # Add on-the-fly speed perturbation; since originally it would
-            # have increased epoch size by 3, we will apply prob 2/3 and use
-            # 3x more epochs.
-            # Speed perturbation probably should come first before
-            # concatenation, but in principle the transforms order doesn't have
-            # to be strict (e.g. could be randomized)
-            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
-            # Drop feats to be on the safe side.
-            train = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
-                ),
-                input_transforms=input_transforms,
-                return_cuts=self.args.return_cuts,
-            )
 
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
             train_sampler = DynamicBucketingSampler(
@@ -322,7 +307,7 @@ class AsrDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=True,
+            persistent_workers=self.args.num_workers > 0,
             pin_memory=True,
             worker_init_fn=worker_init_fn,
         )
@@ -345,19 +330,26 @@ class AsrDataModule:
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
-
-        valid_sampler = DynamicBucketingSampler(
-            cuts_valid,
-            max_duration=self.args.max_duration,
-            shuffle=False,
-        )
+        if self.args.bucketing_sampler:
+            valid_sampler = DynamicBucketingSampler(
+                cuts_valid,
+                max_duration=self.args.max_duration,
+                shuffle=False,
+            )
+        else:
+            valid_sampler = SimpleCutSampler(
+                cuts_valid,
+                max_duration=self.args.max_duration,
+                shuffle=False,
+            )
         logging.info("About to create dev dataloader")
+        valid_num_workers = 1
         valid_dl = DataLoader(
             validate,
             sampler=valid_sampler,
             batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=valid_num_workers,
+            persistent_workers=valid_num_workers > 0,
         )
 
         return valid_dl
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
index 0b2642bf0..d23d578c6 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
@@ -81,7 +81,15 @@ from utils import (  # filter_uneven_sized_batch,
 )
 
 DEFAULT_SPEECH_TOKEN = "<speech>"
 
+# The on-the-fly WhisperFbank extractor runs on CUDA inside dataloader
+# workers, which requires the 'spawn' start method.
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    # The start method may already have been set elsewhere; that is fine.
+    pass
+
 
 def set_batch_count(model: nn.Module, batch_count: float) -> None:
     for module in model.modules():
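
--
Reviewer note (not part of the patch): below is a minimal, self-contained
sketch of what the new --on-the-fly-feats path does, using lhotse's
OnTheFlyFeatures with WhisperFbank. The manifest path, max_duration, and
worker count mirror the recipe, but the standalone wiring is illustrative
rather than the exact code in data_module.py:

    from lhotse import CutSet, WhisperFbank, WhisperFbankConfig
    from lhotse.dataset import K2SpeechRecognitionDataset, SimpleCutSampler
    from lhotse.dataset.input_strategies import OnTheFlyFeatures
    from torch.utils.data import DataLoader

    cuts = CutSet.from_file("data/fbank/cuts_voice_assistant.00000.jsonl.gz")

    # Instead of reading precomputed fbank matrices from disk, features are
    # computed from the raw audio when each batch is collated.
    dataset = K2SpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(
            WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
        ),
        return_cuts=True,
    )
    sampler = SimpleCutSampler(cuts, max_duration=150.0)

    # batch_size=None because the lhotse sampler already yields full batches;
    # with num_workers > 0 and a CUDA extractor, the 'spawn' start method is
    # required (hence the change in train.py).
    dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=4)
    batch = next(iter(dl))  # batch["inputs"] holds (N, T, 80) fbank features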