support weighted sampler

This commit is contained in:
marcoyang 2024-08-21 16:47:42 +08:00
parent 1e34f0e2e0
commit fcf06872a2
3 changed files with 104 additions and 27 deletions

View File

@ -10,6 +10,7 @@ stage=-1
stop_stage=4 stop_stage=4
dl_dir=$PWD/download dl_dir=$PWD/download
fbank_dir=data/fbank
# we assume that you have your downloaded the AudioSet and placed # we assume that you have your downloaded the AudioSet and placed
# it under $dl_dir/audioset, the folder structure should look like # it under $dl_dir/audioset, the folder structure should look like
@ -49,7 +50,6 @@ fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set" log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
fbank_dir=data/fbank
if [! -e $fbank_dir/.balanced.done]; then if [! -e $fbank_dir/.balanced.done]; then
python local/generate_audioset_manifest.py \ python local/generate_audioset_manifest.py \
--dataset-dir $dl_dir/audioset \ --dataset-dir $dl_dir/audioset \
@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
touch data/fbank/.musan.done touch data/fbank/.musan.done
fi fi
fi fi
# The following stages are required to do weighted-sampling training
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare for weighted-sampling training"
if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then
lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz
fi
python ./local/compute_weight.py \
--input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \
--output $fbank_dir/sampling_weights_full.txt
fi

View File

@ -31,6 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
PrecomputedFeatures, PrecomputedFeatures,
SimpleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
WeightedSimpleCutSampler,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples, AudioSamples,
@ -99,6 +100,20 @@ class AudioSetATDatamodule:
help="Maximum pooled recordings duration (seconds) in a " help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.", "single batch. You can reduce it if it causes CUDA OOM.",
) )
group.add_argument(
"--weighted-sampler",
type=str2bool,
default=False,
help="When enabled, samples are drawn from by their weights. "
"It cannot be used together with bucketing sampler",
)
group.add_argument(
"--num-samples",
type=int,
default=200000,
help="The number of samples to be drawn in each epoch. Only be used"
"for weighed sampler",
)
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
@ -295,6 +310,9 @@ class AudioSetATDatamodule:
) )
if self.args.bucketing_sampler: if self.args.bucketing_sampler:
assert (
not self.args.weighted_sampler
), "weighted sampling is not supported in bucket sampler"
logging.info("Using DynamicBucketingSampler.") logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler( train_sampler = DynamicBucketingSampler(
cuts_train, cuts_train,
@ -304,13 +322,26 @@ class AudioSetATDatamodule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SimpleCutSampler.") if self.args.weighted_sampler:
train_sampler = SimpleCutSampler( # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset"
cuts_train, logging.info("Using weighted SimpleCutSampler")
max_duration=self.args.max_duration, weights = self.audioset_sampling_weights()
shuffle=self.args.shuffle, train_sampler = WeightedSimpleCutSampler(
drop_last=self.args.drop_last, cuts_train,
) weights,
num_samples=self.args.num_samples,
max_duration=self.args.max_duration,
shuffle=False, # do not support shuffle
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
drop_last=self.args.drop_last,
)
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
if sampler_state_dict is not None: if sampler_state_dict is not None:
@ -373,11 +404,9 @@ class AudioSetATDatamodule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = AudioTaggingDataset( test = AudioTaggingDataset(
input_strategy=( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats
if self.args.on_the_fly_feats else eval(self.args.input_strategy)(),
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(
@ -397,21 +426,30 @@ class AudioSetATDatamodule:
@lru_cache() @lru_cache()
def audioset_train_cuts(self) -> CutSet: def audioset_train_cuts(self) -> CutSet:
logging.info("About to get the audioset training cuts.") logging.info("About to get the audioset training cuts.")
balanced_cuts = load_manifest_lazy( if not self.args.weighted_sampler:
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" balanced_cuts = load_manifest_lazy(
) self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
if self.args.audioset_subset == "full":
unbalanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
)
cuts = CutSet.mux(
balanced_cuts,
unbalanced_cuts,
weights=[20000, 2000000],
stop_early=True,
) )
if self.args.audioset_subset == "full":
unbalanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
)
cuts = CutSet.mux(
balanced_cuts,
unbalanced_cuts,
weights=[20000, 2000000],
stop_early=True,
)
else:
cuts = balanced_cuts
else: else:
cuts = balanced_cuts # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet"
cuts = load_manifest(
self.args.manifest_dir
/ f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz"
)
logging.info(f"Get {len(cuts)} cuts in total.")
return cuts return cuts
@lru_cache() @lru_cache()
@ -420,3 +458,22 @@ class AudioSetATDatamodule:
return load_manifest_lazy( return load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz" self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
) )
@lru_cache()
def audioset_sampling_weights(self):
logging.info(
f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet"
)
weights = []
with open(
self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt",
"r",
) as f:
while True:
line = f.readline()
if not line:
break
weight = float(line.split()[1])
weights.append(weight)
logging.info(f"Get the sampling weight for {len(weights)} cuts")
return weights

View File

@ -789,12 +789,14 @@ def train_one_epoch(
rank=0, rank=0,
) )
num_samples = 0
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0: if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params)) set_batch_count(model, get_adjusted_batch_count(params))
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = batch["inputs"].size(0) batch_size = batch["inputs"].size(0)
num_samples += batch_size
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.cuda.amp.autocast(enabled=params.use_fp16):
@ -919,6 +921,12 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train tb_writer, "train/valid_", params.batch_idx_train
) )
if num_samples > params.num_samples:
logging.info(
f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch"
)
break
loss_value = tot_loss["loss"] / tot_loss["frames"] loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value params.train_loss = loss_value
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
@ -1032,7 +1040,8 @@ def run(rank, world_size, args):
return True return True
train_cuts = train_cuts.filter(remove_short_and_long_utt) if not params.weighted_sampler:
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint # We only load the sampler's state dict when it loads a checkpoint