mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

commit b20a0d0e35
parent bd2df570ad

    add on the fly feature
@@ -174,13 +174,13 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
     --prefix gigaspeech
 fi
 
 
-ngpu=2
 # cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+ngpu=4
 exp_dir=./qwen_omni/exp_speech2speech_en
 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   log "stage 10: Training Speech2Speech Model"
   torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
-    --max-duration 50 \
+    --max-duration 150 \
     --enable-musan False \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/large-v2.pt \
@@ -189,6 +189,6 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
     --manifest-dir data/fbank \
     --deepspeed \
     --deepspeed_config ./qwen_omni/ds_config_zero1.json \
-    --use-flash-attn True \
+    --use-flash-attn True --on-the-fly-feats True \
     --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 fi
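Note: the new --on-the-fly-feats True switch makes the data module compute 80-dim WhisperFbank features inside the dataloader instead of reading precomputed features from data/fbank (see the data-module hunks below). A minimal sketch of what the flag selects, using the lhotse API as it appears in this diff; the manifest path is illustrative:

    from lhotse import CutSet, WhisperFbank, WhisperFbankConfig
    from lhotse.dataset import K2SpeechRecognitionDataset, OnTheFlyFeatures

    cuts = CutSet.from_file("data/fbank/cuts_voice_assistant.00000.jsonl.gz")
    dataset = K2SpeechRecognitionDataset(
        # Features are computed from audio per batch in the dataloader
        # worker rather than loaded from precomputed fbank manifests.
        input_strategy=OnTheFlyFeatures(
            WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
        ),
        return_cuts=True,
    )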
@@ -50,7 +50,6 @@ from torch.utils.data import DataLoader
 
 from utils import str2bool
-
 
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -149,7 +148,7 @@ class AsrDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=2,
+            default=4,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
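Note: the hunks below read self.args.on_the_fly_feats, but the hunk that registers the flag is not shown in this diff. A presumed sketch of the companion argument, following the add_argument pattern above and the str2bool import added earlier; the exact help text is an assumption:

    group.add_argument(
        "--on-the-fly-feats",
        type=str2bool,
        default=False,
        help="When enabled, compute WhisperFbank features in the dataloader "
        "instead of loading precomputed features from the manifest directory.",
    )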
@@ -262,31 +261,35 @@ class AsrDataModule:
 
         logging.info("About to create train dataset")
         train = K2SpeechRecognitionDataset(
-            input_strategy=eval(self.args.input_strategy)(),
+            input_strategy=OnTheFlyFeatures(
+                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+            )
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
             cut_transforms=transforms,
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
 
-        if self.args.on_the_fly_feats:
-            # NOTE: the PerturbSpeed transform should be added only if we
-            # remove it from data prep stage.
-            # Add on-the-fly speed perturbation; since originally it would
-            # have increased epoch size by 3, we will apply prob 2/3 and use
-            # 3x more epochs.
-            # Speed perturbation probably should come first before
-            # concatenation, but in principle the transforms order doesn't have
-            # to be strict (e.g. could be randomized)
-            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
-            # Drop feats to be on the safe side.
-            train = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
-                ),
-                input_transforms=input_transforms,
-                return_cuts=self.args.return_cuts,
-            )
+        # if self.args.on_the_fly_feats:
+        #     # NOTE: the PerturbSpeed transform should be added only if we
+        #     # remove it from data prep stage.
+        #     # Add on-the-fly speed perturbation; since originally it would
+        #     # have increased epoch size by 3, we will apply prob 2/3 and use
+        #     # 3x more epochs.
+        #     # Speed perturbation probably should come first before
+        #     # concatenation, but in principle the transforms order doesn't have
+        #     # to be strict (e.g. could be randomized)
+        #     # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
+        #     # Drop feats to be on the safe side.
+        #     train = K2SpeechRecognitionDataset(
+        #         cut_transforms=transforms,
+        #         input_strategy=OnTheFlyFeatures(
+        #             WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
+        #         ),
+        #         input_transforms=input_transforms,
+        #         return_cuts=self.args.return_cuts,
+        #     )
 
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
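Note: previously the dataset was built twice when the flag was set (the second K2SpeechRecognitionDataset replaced the first); the diff folds that into a single inline conditional and keeps the old branch only as a comment. Because the extractor is created with device="cuda" and runs inside dataloader worker processes, train.py switches the multiprocessing start method to 'spawn' in the last hunk. A hedged refactor sketch of the same decision without the eval() dispatch; make_input_strategy is a hypothetical helper, not part of this commit:

    from lhotse import WhisperFbank, WhisperFbankConfig
    from lhotse.dataset import OnTheFlyFeatures, PrecomputedFeatures

    def make_input_strategy(on_the_fly_feats: bool):
        # Same decision as the inline conditional in the diff, without
        # eval()-ing a strategy class name taken from the command line.
        if on_the_fly_feats:
            return OnTheFlyFeatures(
                WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
            )
        return PrecomputedFeatures()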
@@ -322,7 +325,7 @@ class AsrDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=True,
+            persistent_workers=True if self.args.num_workers > 0 else False,
             pin_memory=True,
             worker_init_fn=worker_init_fn,
         )
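Note: the guard matters because torch.utils.data.DataLoader rejects persistent_workers=True when num_workers is 0, so the previous hard-coded True would crash single-process loading. A quick demonstration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.zeros(4, 1))
    try:
        DataLoader(ds, num_workers=0, persistent_workers=True)
    except ValueError as e:
        print(e)  # "persistent_workers option needs num_workers > 0"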
@@ -345,19 +348,26 @@ class AsrDataModule:
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
 
         if self.args.bucketing_sampler:
             valid_sampler = DynamicBucketingSampler(
                 cuts_valid,
                 max_duration=self.args.max_duration,
                 shuffle=False,
             )
         else:
             valid_sampler = SimpleCutSampler(
                 cuts_valid,
                 max_duration=self.args.max_duration,
                 shuffle=False,
             )
         logging.info("About to create dev dataloader")
+        valid_num_workers = 1
         valid_dl = DataLoader(
             validate,
             sampler=valid_sampler,
             batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=valid_num_workers,
+            persistent_workers=True if valid_num_workers > 0 else False,
         )
 
         return valid_dl
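Note: the validation loader now follows the same sampler choice as training. DynamicBucketingSampler groups cuts of similar duration under a max_duration budget so batches waste less padding, while SimpleCutSampler fills batches in manifest order. A hedged usage sketch (manifest path illustrative):

    from lhotse import CutSet
    from lhotse.dataset import DynamicBucketingSampler

    cuts = CutSet.from_file("data/fbank/cuts_voice_assistant.00000.jsonl.gz")
    sampler = DynamicBucketingSampler(cuts, max_duration=150, shuffle=False)
    for batch_cuts in sampler:  # each batch is a CutSet of similar-length cuts
        print(sum(c.duration for c in batch_cuts))  # stays around <= 150 s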
@@ -450,3 +460,25 @@ class AsrDataModule:
             self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
         )
         return VoiceAssistant_cuts
+
+    # def train_cuts_en_vocalnet(self) -> CutSet:
+    #     logging.info("About to get train cuts")
+    #     VoiceAssistant_cuts = load_manifest_lazy(
+    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
+    #     )
+    #     return VoiceAssistant_cuts
+
+    # @lru_cache()
+    # def valid_cuts_en_vocalnet(self) -> CutSet:
+    #     logging.info("About to get valid cuts")
+    #     VoiceAssistant_cuts = load_manifest_lazy(
+    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
+    #     )
+    #     return VoiceAssistant_cuts
+
+    # @lru_cache()
+    # def test_cuts_en_vocalnet(self) -> CutSet:
+    #     logging.info("About to get test cuts")
+    #     VoiceAssistant_cuts = load_manifest_lazy(
+    #         self.args.manifest_dir / "cuts_debug.jsonl.gz"
+    #     )
+    #     return VoiceAssistant_cuts
@@ -81,7 +81,10 @@ from utils import (  # filter_uneven_sized_batch,
 )
 
 DEFAULT_SPEECH_TOKEN = "<speech>"
 
+try:
+    torch.multiprocessing.set_start_method('spawn')
+except RuntimeError:
+    pass
+
 
 def set_batch_count(model: nn.Module, batch_count: float) -> None:
     for module in model.modules():
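Note: the 'spawn' start method is needed here because CUDA cannot be initialized in a forked child process, and with --on-the-fly-feats True the WhisperFbank extractor runs on device="cuda" inside dataloader workers. set_start_method raises RuntimeError if the context was already set, hence the try/except. Minimal illustration of the pattern:

    import torch

    if __name__ == "__main__":
        try:
            torch.multiprocessing.set_start_method("spawn")
        except RuntimeError:
            pass  # start method already set elsewhere; keep it
        # Workers created by DataLoader(num_workers > 0) are now spawned,
        # so they can safely initialize CUDA for feature extraction.
        print(torch.multiprocessing.get_start_method())  # -> "spawn"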