mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
add on the fly feature
This commit is contained in:
parent
bd2df570ad
commit
b20a0d0e35
@ -174,13 +174,13 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
|||||||
--prefix gigaspeech
|
--prefix gigaspeech
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
|
||||||
ngpu=2
|
ngpu=4
|
||||||
exp_dir=./qwen_omni/exp_speech2speech_en
|
exp_dir=./qwen_omni/exp_speech2speech_en
|
||||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||||
log "stage 10: Training Speech2Speech Model"
|
log "stage 10: Training Speech2Speech Model"
|
||||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||||
--max-duration 50 \
|
--max-duration 150 \
|
||||||
--enable-musan False \
|
--enable-musan False \
|
||||||
--exp-dir $exp_dir \
|
--exp-dir $exp_dir \
|
||||||
--speech-encoder-path-or-name models/large-v2.pt \
|
--speech-encoder-path-or-name models/large-v2.pt \
|
||||||
@ -189,6 +189,6 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
|||||||
--manifest-dir data/fbank \
|
--manifest-dir data/fbank \
|
||||||
--deepspeed \
|
--deepspeed \
|
||||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||||
--use-flash-attn True \
|
--use-flash-attn True --on-the-fly-feats True \
|
||||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||||
fi
|
fi
|
||||||
|
@ -50,7 +50,6 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
from utils import str2bool
|
from utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
def __init__(self, seed: int):
|
def __init__(self, seed: int):
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
@ -149,7 +148,7 @@ class AsrDataModule:
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--num-workers",
|
"--num-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=4,
|
||||||
help="The number of training dataloader workers that "
|
help="The number of training dataloader workers that "
|
||||||
"collect the batches.",
|
"collect the batches.",
|
||||||
)
|
)
|
||||||
@ -262,31 +261,35 @@ class AsrDataModule:
|
|||||||
|
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
input_strategy=eval(self.args.input_strategy)(),
|
input_strategy=OnTheFlyFeatures(
|
||||||
|
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
|
||||||
|
)
|
||||||
|
if self.args.on_the_fly_feats
|
||||||
|
else eval(self.args.input_strategy)(),
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.on_the_fly_feats:
|
# if self.args.on_the_fly_feats:
|
||||||
# NOTE: the PerturbSpeed transform should be added only if we
|
# # NOTE: the PerturbSpeed transform should be added only if we
|
||||||
# remove it from data prep stage.
|
# # remove it from data prep stage.
|
||||||
# Add on-the-fly speed perturbation; since originally it would
|
# # Add on-the-fly speed perturbation; since originally it would
|
||||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
# # have increased epoch size by 3, we will apply prob 2/3 and use
|
||||||
# 3x more epochs.
|
# # 3x more epochs.
|
||||||
# Speed perturbation probably should come first before
|
# # Speed perturbation probably should come first before
|
||||||
# concatenation, but in principle the transforms order doesn't have
|
# # concatenation, but in principle the transforms order doesn't have
|
||||||
# to be strict (e.g. could be randomized)
|
# # to be strict (e.g. could be randomized)
|
||||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
# # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||||
# Drop feats to be on the safe side.
|
# # Drop feats to be on the safe side.
|
||||||
train = K2SpeechRecognitionDataset(
|
# train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
# cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(
|
# input_strategy=OnTheFlyFeatures(
|
||||||
WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
|
# WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))
|
||||||
),
|
# ),
|
||||||
input_transforms=input_transforms,
|
# input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
# return_cuts=self.args.return_cuts,
|
||||||
)
|
# )
|
||||||
|
|
||||||
if self.args.bucketing_sampler:
|
if self.args.bucketing_sampler:
|
||||||
logging.info("Using DynamicBucketingSampler.")
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
@ -322,7 +325,7 @@ class AsrDataModule:
|
|||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
persistent_workers=True,
|
persistent_workers=True if self.args.num_workers > 0 else False,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
)
|
)
|
||||||
@ -345,19 +348,26 @@ class AsrDataModule:
|
|||||||
else eval(self.args.input_strategy)(),
|
else eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
if self.args.bucketing_sampler:
|
||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
valid_sampler = SimpleCutSampler(
|
||||||
|
cuts_valid,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
|
valid_num_workers = 1
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
validate,
|
validate,
|
||||||
sampler=valid_sampler,
|
sampler=valid_sampler,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
num_workers=2,
|
num_workers=valid_num_workers,
|
||||||
persistent_workers=False,
|
persistent_workers=True if valid_num_workers > 0 else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return valid_dl
|
return valid_dl
|
||||||
@ -450,3 +460,25 @@ class AsrDataModule:
|
|||||||
self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
|
self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
|
||||||
)
|
)
|
||||||
return VoiceAssistant_cuts
|
return VoiceAssistant_cuts
|
||||||
|
# def train_cuts_en_vocalnet(self) -> CutSet:
|
||||||
|
# logging.info("About to get train cuts")
|
||||||
|
# VoiceAssistant_cuts = load_manifest_lazy(
|
||||||
|
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
|
||||||
|
# )
|
||||||
|
# return VoiceAssistant_cuts
|
||||||
|
|
||||||
|
# @lru_cache()
|
||||||
|
# def valid_cuts_en_vocalnet(self) -> CutSet:
|
||||||
|
# logging.info("About to get valid cuts")
|
||||||
|
# VoiceAssistant_cuts = load_manifest_lazy(
|
||||||
|
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
|
||||||
|
# )
|
||||||
|
# return VoiceAssistant_cuts
|
||||||
|
|
||||||
|
# @lru_cache()
|
||||||
|
# def test_cuts_en_vocalnet(self) -> CutSet:
|
||||||
|
# logging.info("About to get test cuts")
|
||||||
|
# VoiceAssistant_cuts = load_manifest_lazy(
|
||||||
|
# self.args.manifest_dir / "cuts_debug.jsonl.gz"
|
||||||
|
# )
|
||||||
|
# return VoiceAssistant_cuts
|
@ -81,7 +81,10 @@ from utils import ( # filter_uneven_sized_batch,
|
|||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_SPEECH_TOKEN = "<speech>"
|
DEFAULT_SPEECH_TOKEN = "<speech>"
|
||||||
|
try:
|
||||||
|
torch.multiprocessing.set_start_method('spawn')
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
def set_batch_count(model: nn.Module, batch_count: float) -> None:
|
def set_batch_count(model: nn.Module, batch_count: float) -> None:
|
||||||
for module in model.modules():
|
for module in model.modules():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user