# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This is modified from https://github.com/microsoft/Swin-Transformer/blob/main/data/build.py
# The default args are copied from https://github.com/microsoft/Swin-Transformer/blob/main/config.py
# We adjust the code style to match other recipes in icefall.


import argparse
import logging
import os
from pathlib import Path
from typing import Optional

from icefall.dist import get_rank, get_world_size
from icefall.utils import str2bool
from timm.data import Mixup, create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

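# Compatibility shim: newer torchvision versions expose InterpolationMode enums.
# When available, define _pil_interp on top of them and monkey-patch timm so its
# transforms use the same mapping; otherwise fall back to timm's own _pil_interp.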
try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        if method == "bicubic":
            return InterpolationMode.BICUBIC
        elif method == "lanczos":
            return InterpolationMode.LANCZOS
        elif method == "hamming":
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR

    import timm.data.transforms as timm_transforms

    timm_transforms._pil_interp = _pil_interp
except:  # noqa
    from timm.data.transforms import _pil_interp


class ImageNetClsDataModule:
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.rank = get_rank()
        self.world_size = get_world_size()

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="Image classification data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders -- they control the effective batch sizes, "
            "sampling strategies, applied data augmentations, etc.",
        )

        group.add_argument(
            "--data-path",
            type=Path,
            default=Path("imagenet"),
            help="Path to the ImageNet dataset",
        )

        group.add_argument(
            "--batch-size",
            type=int,
            default=128,
            help="Batch size for a single GPU",
        )

        group.add_argument(
            "--color-jitter",
            type=float,
            default=0.4,
            help="Color jitter factor",
        )

        group.add_argument(
            "--auto-augment",
            type=str,
            default="rand-m9-mstd0.5-inc1",
            help="AutoAugment policy, e.g. 'rand-m9-mstd0.5-inc1', 'v0', or "
            "'original'. Use 'none' to disable auto augmentation.",
        )

        group.add_argument(
            "--reprob",
            type=float,
            default=0.25,
            help="Random erase prob",
        )

        group.add_argument(
            "--remode",
            type=str,
            default="pixel",
            help="Random erase mode",
        )

        group.add_argument(
            "--recount",
            type=int,
            default=1,
            help="Random erase count",
        )

        group.add_argument(
            "--interpolation",
            type=str,
            default="bicubic",
            help="Interpolation to resize image (random, bilinear, bicubic)",
        )

        group.add_argument(
            "--crop",
            type=str2bool,
            default=True,
            help="Whether to use center crop when testing",
        )

        group.add_argument(
            "--mixup",
            type=float,
            default=0.8,
            help="Mixup alpha, mixup enabled if > 0",
        )

        group.add_argument(
            "--cutmix",
            type=float,
            default=1.0,
            help="Cutmix alpha, cutmix enabled if > 0",
        )

        group.add_argument(
            "--cutmix-minmax",
            type=float,
            nargs="+",  # timm's Mixup expects a (min, max) pair when this is set
            default=None,
            help="Cutmix min/max ratio (two values), overrides alpha and "
            "enables cutmix if set",
        )

        group.add_argument(
            "--mixup-prob",
            type=float,
            default=1.0,
            help="Probability of performing mixup or cutmix when either/both is enabled",
        )

        group.add_argument(
            "--mixup-switch-prob",
            type=float,
            default=0.5,
            help="Probability of switching to cutmix when both mixup and cutmix enabled",
        )

        group.add_argument(
            "--mixup-mode",
            type=str,
            default="batch",
            help="How to apply mixup/cutmix params. Per 'batch', 'pair', or 'elem'",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="Number of data loading worker processes",
        )

        group.add_argument(
            "--pin-memory",
            type=str2bool,
            default=True,
            help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
        )

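    # build_transform() assembles the preprocessing pipeline:
    # - training: timm's create_transform() (random resized crop, color jitter,
    #   auto augmentation, random erasing); for small inputs (img_size <= 32) the
    #   first transform is replaced by a plain RandomCrop with padding;
    # - evaluation: resize + optional center crop, then ToTensor and Normalize
    #   with the ImageNet mean/std. With --crop=True and img-size=224, the image
    #   is first resized to int(256 / 224 * 224) = 256 and then center-cropped
    #   to 224, matching the usual ImageNet evaluation protocol.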
    def build_transform(self, is_training: bool = False):
        resize_im = self.args.img_size > 32
        if is_training:
            # this should always dispatch to transforms_imagenet_train
            transform = create_transform(
                input_size=self.args.img_size,
                is_training=True,
                color_jitter=self.args.color_jitter
                if self.args.color_jitter > 0
                else None,
                auto_augment=self.args.auto_augment
                if self.args.auto_augment != "none"
                else None,
                re_prob=self.args.reprob,
                re_mode=self.args.remode,
                re_count=self.args.recount,
                interpolation=self.args.interpolation,
            )
            if not resize_im:
                # replace RandomResizedCropAndInterpolation with
                # RandomCrop
                transform.transforms[0] = transforms.RandomCrop(
                    self.args.img_size, padding=4
                )
            return transform

        t = []
        if resize_im:
            if self.args.crop:
                size = int((256 / 224) * self.args.img_size)
                t.append(
                    transforms.Resize(
                        size, interpolation=_pil_interp(self.args.interpolation)
                    ),
                    # to maintain same ratio w.r.t. 224 images
                )
                t.append(transforms.CenterCrop(self.args.img_size))
            else:
                t.append(
                    transforms.Resize(
                        (self.args.img_size, self.args.img_size),
                        interpolation=_pil_interp(self.args.interpolation),
                    )
                )

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
        return transforms.Compose(t)

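    # build_dataset() uses torchvision's ImageFolder, so --data-path is expected
    # to follow the standard layout with one sub-directory per class:
    #   <data-path>/train/<class>/*.JPEG
    #   <data-path>/val/<class>/*.JPEG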
    def build_dataset(self, is_training: bool = False):
        transform = self.build_transform(is_training)
        prefix = "train" if is_training else "val"
        root = os.path.join(self.args.data_path, prefix)
        dataset = datasets.ImageFolder(root, transform=transform)
        return dataset

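    # build_train_loader() returns a (DataLoader, mixup_fn) pair. The optional
    # timm Mixup callable is applied per batch in the training loop rather than
    # inside the dataset; the assert below pins the recipe to ImageNet-1k.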
    def build_train_loader(
        self, num_classes: int, label_smoothing: Optional[float] = None
    ):
        assert num_classes == 1000, num_classes
        dataset_train = self.build_dataset(is_training=True)
        logging.info(f"rank {self.rank} successfully built train dataset")

        if self.world_size > 1:
            sampler_train = DistributedSampler(
                dataset_train,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=True,
            )
        else:
            sampler_train = RandomSampler(dataset_train)

        # TODO: need to set up worker_init_fn?
        data_loader_train = DataLoader(
            dataset_train,
            sampler=sampler_train,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            drop_last=True,
        )

        # setup mixup / cutmix
        mixup_fn = None
        mixup_active = (
            self.args.mixup > 0
            or self.args.cutmix > 0.0
            or self.args.cutmix_minmax is not None
        )
        if mixup_active:
            mixup_fn = Mixup(
                mixup_alpha=self.args.mixup,
                cutmix_alpha=self.args.cutmix,
                cutmix_minmax=self.args.cutmix_minmax,
                prob=self.args.mixup_prob,
                switch_prob=self.args.mixup_switch_prob,
                mode=self.args.mixup_mode,
                label_smoothing=label_smoothing,
                num_classes=num_classes,
            )

        return data_loader_train, mixup_fn

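    # A minimal sketch (illustrative only, not part of the original recipe) of
    # how the returned pair is typically consumed in the training script:
    #
    #     train_dl, mixup_fn = data_module.build_train_loader(
    #         num_classes=1000, label_smoothing=0.1
    #     )
    #     for images, targets in train_dl:
    #         if mixup_fn is not None:
    #             # Mixup returns mixed images and soft-label targets
    #             images, targets = mixup_fn(images, targets)
    #         loss = criterion(model(images), targets)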
    def build_val_loader(self):
        dataset_val = self.build_dataset(is_training=False)
        logging.info(f"rank {self.rank} successfully built val dataset")

        if self.world_size > 1:
            sampler_val = DistributedSampler(
                dataset_val, num_replicas=self.world_size, rank=self.rank, shuffle=False
            )
        else:
            sampler_val = SequentialSampler(dataset_val)

        data_loader_val = DataLoader(
            dataset_val,
            sampler=sampler_val,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            drop_last=False,
        )

        return data_loader_val


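# A small usage sketch, not part of the original recipe: the real training
# script defines further options (notably --img-size, which build_transform()
# reads from args); here we add --img-size locally so the data module can be
# exercised on its own against a local ImageNet copy.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    ImageNetClsDataModule.add_arguments(parser)
    parser.add_argument("--img-size", type=int, default=224)
    args = parser.parse_args()

    data_module = ImageNetClsDataModule(args)
    train_dl, mixup_fn = data_module.build_train_loader(
        num_classes=1000, label_smoothing=0.1
    )
    val_dl = data_module.build_val_loader()
    logging.info(f"train batches: {len(train_dl)}, val batches: {len(val_dl)}")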