add on-the-fly feature extraction

root 2025-05-08 19:21:41 -07:00
parent bd2df570ad
commit b20a0d0e35
3 changed files with 71 additions and 36 deletions


@@ -174,13 +174,13 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
     --prefix gigaspeech
 fi
 
 # cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
-ngpu=2
+ngpu=4
 exp_dir=./qwen_omni/exp_speech2speech_en
 
 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   log "stage 10: Training Speech2Speech Model"
   torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
-    --max-duration 50 \
+    --max-duration 150 \
     --enable-musan False \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/large-v2.pt \
@@ -189,6 +189,6 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
-    --use-flash-attn True \
+    --use-flash-attn True --on-the-fly-feats True \
     --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 fi


@@ -50,7 +50,6 @@ from torch.utils.data import DataLoader
 from utils import str2bool
 
-
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -149,7 +148,7 @@ class AsrDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=2,
+            default=4,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
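The run script above now passes --on-the-fly-feats, and this module reads it off self.args. The declaration of the flag itself is not part of this diff; in lhotse/icefall-style data modules it is typically an argparse option parsed through the str2bool helper imported above. A minimal sketch of that pattern (hypothetical names, not the repo's exact code):

import argparse


def str2bool(v: str) -> bool:
    # Simplified stand-in for the utils.str2bool helper this module imports.
    if v.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if v.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--on-the-fly-feats",
    type=str2bool,
    default=False,
    help="When True, compute fbank features in the dataloader workers "
    "instead of loading precomputed features from disk.",
)

args = parser.parse_args(["--on-the-fly-feats", "True"])
assert args.on_the_fly_feats is True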
@@ -262,31 +261,35 @@ class AsrDataModule:
         logging.info("About to create train dataset")
         train = K2SpeechRecognitionDataset(
-            input_strategy=eval(self.args.input_strategy)(),
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+            )
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
             cut_transforms=transforms,
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
-        if self.args.on_the_fly_feats:
-            # NOTE: the PerturbSpeed transform should be added only if we
-            # remove it from data prep stage.
-            # Add on-the-fly speed perturbation; since originally it would
-            # have increased epoch size by 3, we will apply prob 2/3 and use
-            # 3x more epochs.
-            # Speed perturbation probably should come first before
-            # concatenation, but in principle the transforms order doesn't have
-            # to be strict (e.g. could be randomized)
-            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
-            # Drop feats to be on the safe side.
-            train = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
-                ),
-                input_transforms=input_transforms,
-                return_cuts=self.args.return_cuts,
-            )
+        # if self.args.on_the_fly_feats:
+        #     # NOTE: the PerturbSpeed transform should be added only if we
+        #     # remove it from data prep stage.
+        #     # Add on-the-fly speed perturbation; since originally it would
+        #     # have increased epoch size by 3, we will apply prob 2/3 and use
+        #     # 3x more epochs.
+        #     # Speed perturbation probably should come first before
+        #     # concatenation, but in principle the transforms order doesn't have
+        #     # to be strict (e.g. could be randomized)
+        #     # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
+        #     # Drop feats to be on the safe side.
+        #     train = K2SpeechRecognitionDataset(
+        #         cut_transforms=transforms,
+        #         input_strategy=OnTheFlyFeatures(
+        #             WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+        #         ),
+        #         input_transforms=input_transforms,
+        #         return_cuts=self.args.return_cuts,
+        #     )
 
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
@@ -322,7 +325,7 @@ class AsrDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=True,
+            persistent_workers=True if self.args.num_workers > 0 else False,
             pin_memory=True,
             worker_init_fn=worker_init_fn,
         )
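The persistent_workers guard matters because PyTorch's DataLoader raises a ValueError when persistent_workers=True is combined with num_workers=0, and --num-workers is user-configurable. A small sketch of the failure mode and the guard:

from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> int:
        return idx


num_workers = 0  # e.g. passed as --num-workers 0 for debugging

# persistent_workers=True with num_workers=0 would raise:
#   ValueError: persistent_workers option needs num_workers > 0
dl = DataLoader(
    TinyDataset(),
    num_workers=num_workers,
    persistent_workers=num_workers > 0,  # the guard from the diff
)
print(list(dl))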
@@ -345,19 +348,26 @@ class AsrDataModule:
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
-        valid_sampler = DynamicBucketingSampler(
-            cuts_valid,
-            max_duration=self.args.max_duration,
-            shuffle=False,
-        )
+        if self.args.bucketing_sampler:
+            valid_sampler = DynamicBucketingSampler(
+                cuts_valid,
+                max_duration=self.args.max_duration,
+                shuffle=False,
+            )
+        else:
+            valid_sampler = SimpleCutSampler(
+                cuts_valid,
+                max_duration=self.args.max_duration,
+                shuffle=False,
+            )
         logging.info("About to create dev dataloader")
+        valid_num_workers = 1
         valid_dl = DataLoader(
             validate,
             sampler=valid_sampler,
             batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=valid_num_workers,
+            persistent_workers=True if valid_num_workers > 0 else False,
         )
         return valid_dl
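The validation path now mirrors the training path's sampler choice instead of unconditionally bucketing. Both lhotse samplers take the same max_duration (total seconds of audio per batch) and shuffle arguments; DynamicBucketingSampler additionally groups cuts of similar length to reduce padding. A sketch of the two options (assumes a cut manifest exists; the path and max_duration value are illustrative):

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler

cuts = CutSet.from_file("data/fbank/cuts_voice_assistant.00000.jsonl.gz")

use_bucketing = False
if use_bucketing:
    # Batches cuts of similar duration together to minimize padding.
    sampler = DynamicBucketingSampler(cuts, max_duration=150, shuffle=False)
else:
    # Straightforward sequential batching; fine for a small dev set.
    sampler = SimpleCutSampler(cuts, max_duration=150, shuffle=False)

# Samplers yield CutSet mini-batches whose total duration <= max_duration.
for batch in sampler:
    print(len(batch), sum(c.duration for c in batch))
    break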
@@ -450,3 +460,25 @@ class AsrDataModule:
             self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
         )
         return VoiceAssistant_cuts
+
+    # def train_cuts_en_vocalnet(self) -> CutSet:
+    #     logging.info("About to get train cuts")
+    #     VoiceAssistant_cuts = load_manifest_lazy(
+    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
+    #     )
+    #     return VoiceAssistant_cuts
+
+    # @lru_cache()
+    # def valid_cuts_en_vocalnet(self) -> CutSet:
+    #     logging.info("About to get valid cuts")
+    #     VoiceAssistant_cuts = load_manifest_lazy(
+    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
+    #     )
+    #     return VoiceAssistant_cuts
+
+    # @lru_cache()
+    # def test_cuts_en_vocalnet(self) -> CutSet:
+    #     logging.info("About to get test cuts")
+    #     VoiceAssistant_cuts = load_manifest_lazy(
+    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
+    #     )
+    #     return VoiceAssistant_cuts


@@ -81,7 +81,10 @@ from utils import (  # filter_uneven_sized_batch,
 )
 
 DEFAULT_SPEECH_TOKEN = "<speech>"
 
+try:
+    torch.multiprocessing.set_start_method('spawn')
+except RuntimeError:
+    pass
 
 def set_batch_count(model: nn.Module, batch_count: float) -> None:
     for module in model.modules():
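The spawn block exists because the on-the-fly extractor is constructed with device="cuda": CUDA cannot be (re)initialized in worker processes created with fork, the default start method on Linux, so dataloader workers that touch the GPU must be spawned. The try/except keeps the call safe if an earlier import already fixed the start method. A minimal sketch of the pattern:

import torch.multiprocessing as mp

try:
    # set_start_method raises RuntimeError if the context was already set
    # elsewhere; swallowing it makes the call idempotent across imports.
    mp.set_start_method("spawn")
except RuntimeError:
    pass

if __name__ == "__main__":
    print(mp.get_start_method())  # -> "spawn"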