mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Support AudioSet training with weighted sampler (#1727)
This commit is contained in:
parent
5952972294
commit
3fc06cc2b9
@ -35,16 +35,40 @@ python zipformer/train.py \
|
|||||||
--master-port 13455
|
--master-port 13455
|
||||||
```
|
```
|
||||||
|
|
||||||
|
We recommend that you train the model with weighted sampler, as the model converges
|
||||||
|
faster with better performance:
|
||||||
|
|
||||||
|
| Model | mAP |
|
||||||
|
| ------ | ------- |
|
||||||
|
| Zipformer-AT, train with weighted sampler | 46.6 |
|
||||||
|
|
||||||
The evaluation command is:
|
The evaluation command is:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python zipformer/evaluate.py \
|
export CUDA_VISIBLE_DEVICES="4,5,6,7"
|
||||||
--epoch 32 \
|
subset=full
|
||||||
--avg 8 \
|
weighted_sampler=1
|
||||||
--exp-dir zipformer/exp_at_as_full \
|
bucket_sampler=0
|
||||||
--max-duration 500
|
lr_epochs=15
|
||||||
|
|
||||||
|
python zipformer/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--audioset-subset $subset \
|
||||||
|
--num-epochs 120 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--num-events 527 \
|
||||||
|
--lr-epochs $lr_epochs \
|
||||||
|
--exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \
|
||||||
|
--weighted-sampler $weighted_sampler \
|
||||||
|
--bucketing-sampler $bucket_sampler \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--enable-musan True \
|
||||||
|
--master-port 13452
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The command for evaluation is the same. The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler
|
||||||
|
|
||||||
|
|
||||||
#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M
|
#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M
|
||||||
|
|
||||||
|
73
egs/audioset/AT/local/compute_weight.py
Normal file
73
egs/audioset/AT/local/compute_weight.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file generates the manifest and computes the fbank features for AudioSet
|
||||||
|
dataset. The generated manifests and features are stored in data/fbank.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import lhotse
|
||||||
|
from lhotse import load_manifest
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
cuts = load_manifest(args.input_manifest)
|
||||||
|
|
||||||
|
print(f"A total of {len(cuts)} cuts.")
|
||||||
|
|
||||||
|
label_count = [0] * 527 # a total of 527 classes
|
||||||
|
for c in cuts:
|
||||||
|
audio_event = c.supervisions[0].audio_event
|
||||||
|
labels = list(map(int, audio_event.split(";")))
|
||||||
|
for label in labels:
|
||||||
|
label_count[label] += 1
|
||||||
|
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
for c in cuts:
|
||||||
|
audio_event = c.supervisions[0].audio_event
|
||||||
|
labels = list(map(int, audio_event.split(";")))
|
||||||
|
weight = 0
|
||||||
|
for label in labels:
|
||||||
|
weight += 1000 / (label_count[label] + 0.01)
|
||||||
|
f.write(f"{c.id} {weight}\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -10,6 +10,7 @@ stage=-1
|
|||||||
stop_stage=4
|
stop_stage=4
|
||||||
|
|
||||||
dl_dir=$PWD/download
|
dl_dir=$PWD/download
|
||||||
|
fbank_dir=data/fbank
|
||||||
|
|
||||||
# we assume that you have your downloaded the AudioSet and placed
|
# we assume that you have your downloaded the AudioSet and placed
|
||||||
# it under $dl_dir/audioset, the folder structure should look like
|
# it under $dl_dir/audioset, the folder structure should look like
|
||||||
@ -49,7 +50,6 @@ fi
|
|||||||
|
|
||||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
|
log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
|
||||||
fbank_dir=data/fbank
|
|
||||||
if [! -e $fbank_dir/.balanced.done]; then
|
if [! -e $fbank_dir/.balanced.done]; then
|
||||||
python local/generate_audioset_manifest.py \
|
python local/generate_audioset_manifest.py \
|
||||||
--dataset-dir $dl_dir/audioset \
|
--dataset-dir $dl_dir/audioset \
|
||||||
@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
touch data/fbank/.musan.done
|
touch data/fbank/.musan.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# The following stages are required to do weighted-sampling training
|
||||||
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
|
log "Stage 5: Prepare for weighted-sampling training"
|
||||||
|
if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then
|
||||||
|
lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz
|
||||||
|
fi
|
||||||
|
python ./local/compute_weight.py \
|
||||||
|
--input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \
|
||||||
|
--output $fbank_dir/sampling_weights_full.txt
|
||||||
|
fi
|
||||||
|
@ -31,6 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SimpleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
|
WeightedSimpleCutSampler,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
AudioSamples,
|
AudioSamples,
|
||||||
@ -99,6 +100,20 @@ class AudioSetATDatamodule:
|
|||||||
help="Maximum pooled recordings duration (seconds) in a "
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--weighted-sampler",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, samples are drawn from by their weights. "
|
||||||
|
"It cannot be used together with bucketing sampler",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-samples",
|
||||||
|
type=int,
|
||||||
|
default=200000,
|
||||||
|
help="The number of samples to be drawn in each epoch. Only be used"
|
||||||
|
"for weighed sampler",
|
||||||
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--bucketing-sampler",
|
"--bucketing-sampler",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -295,6 +310,9 @@ class AudioSetATDatamodule:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.args.bucketing_sampler:
|
if self.args.bucketing_sampler:
|
||||||
|
assert (
|
||||||
|
not self.args.weighted_sampler
|
||||||
|
), "weighted sampling is not supported in bucket sampler"
|
||||||
logging.info("Using DynamicBucketingSampler.")
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
train_sampler = DynamicBucketingSampler(
|
train_sampler = DynamicBucketingSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
@ -304,13 +322,26 @@ class AudioSetATDatamodule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
if self.args.weighted_sampler:
|
||||||
train_sampler = SimpleCutSampler(
|
# assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset"
|
||||||
cuts_train,
|
logging.info("Using weighted SimpleCutSampler")
|
||||||
max_duration=self.args.max_duration,
|
weights = self.audioset_sampling_weights()
|
||||||
shuffle=self.args.shuffle,
|
train_sampler = WeightedSimpleCutSampler(
|
||||||
drop_last=self.args.drop_last,
|
cuts_train,
|
||||||
)
|
weights,
|
||||||
|
num_samples=self.args.num_samples,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False, # do not support shuffle
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Using SimpleCutSampler.")
|
||||||
|
train_sampler = SimpleCutSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
if sampler_state_dict is not None:
|
if sampler_state_dict is not None:
|
||||||
@ -373,11 +404,9 @@ class AudioSetATDatamodule:
|
|||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = AudioTaggingDataset(
|
test = AudioTaggingDataset(
|
||||||
input_strategy=(
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
if self.args.on_the_fly_feats
|
||||||
if self.args.on_the_fly_feats
|
else eval(self.args.input_strategy)(),
|
||||||
else eval(self.args.input_strategy)()
|
|
||||||
),
|
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
@ -397,21 +426,30 @@ class AudioSetATDatamodule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def audioset_train_cuts(self) -> CutSet:
|
def audioset_train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get the audioset training cuts.")
|
logging.info("About to get the audioset training cuts.")
|
||||||
balanced_cuts = load_manifest_lazy(
|
if not self.args.weighted_sampler:
|
||||||
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
|
balanced_cuts = load_manifest_lazy(
|
||||||
)
|
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
|
||||||
if self.args.audioset_subset == "full":
|
|
||||||
unbalanced_cuts = load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
|
|
||||||
)
|
|
||||||
cuts = CutSet.mux(
|
|
||||||
balanced_cuts,
|
|
||||||
unbalanced_cuts,
|
|
||||||
weights=[20000, 2000000],
|
|
||||||
stop_early=True,
|
|
||||||
)
|
)
|
||||||
|
if self.args.audioset_subset == "full":
|
||||||
|
unbalanced_cuts = load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
|
||||||
|
)
|
||||||
|
cuts = CutSet.mux(
|
||||||
|
balanced_cuts,
|
||||||
|
unbalanced_cuts,
|
||||||
|
weights=[20000, 2000000],
|
||||||
|
stop_early=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cuts = balanced_cuts
|
||||||
else:
|
else:
|
||||||
cuts = balanced_cuts
|
# assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet"
|
||||||
|
cuts = load_manifest(
|
||||||
|
self.args.manifest_dir
|
||||||
|
/ f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz"
|
||||||
|
)
|
||||||
|
logging.info(f"Get {len(cuts)} cuts in total.")
|
||||||
|
|
||||||
return cuts
|
return cuts
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
@ -420,3 +458,22 @@ class AudioSetATDatamodule:
|
|||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
|
self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def audioset_sampling_weights(self):
|
||||||
|
logging.info(
|
||||||
|
f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet"
|
||||||
|
)
|
||||||
|
weights = []
|
||||||
|
with open(
|
||||||
|
self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt",
|
||||||
|
"r",
|
||||||
|
) as f:
|
||||||
|
while True:
|
||||||
|
line = f.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
weight = float(line.split()[1])
|
||||||
|
weights.append(weight)
|
||||||
|
logging.info(f"Get the sampling weight for {len(weights)} cuts")
|
||||||
|
return weights
|
||||||
|
@ -789,12 +789,14 @@ def train_one_epoch(
|
|||||||
rank=0,
|
rank=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_samples = 0
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
if batch_idx % 10 == 0:
|
if batch_idx % 10 == 0:
|
||||||
set_batch_count(model, get_adjusted_batch_count(params))
|
set_batch_count(model, get_adjusted_batch_count(params))
|
||||||
|
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = batch["inputs"].size(0)
|
batch_size = batch["inputs"].size(0)
|
||||||
|
num_samples += batch_size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
@ -919,6 +921,12 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if num_samples > params.num_samples:
|
||||||
|
logging.info(
|
||||||
|
f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
if params.train_loss < params.best_train_loss:
|
if params.train_loss < params.best_train_loss:
|
||||||
@ -1032,7 +1040,8 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
if not params.weighted_sampler:
|
||||||
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
|
|
||||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||||
# We only load the sampler's state dict when it loads a checkpoint
|
# We only load the sampler's state dict when it loads a checkpoint
|
||||||
|
Loading…
x
Reference in New Issue
Block a user