Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
Merge 5beef022852e1dea957530b7f27d467eb7373415 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in: commit 2cd098d344
317  egs/imagenet/CLS/swin_transformer/cls_datamodule.py  Normal file
@ -0,0 +1,317 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This is modified from https://github.com/microsoft/Swin-Transformer/blob/main/data/build.py
# The default args are copied from https://github.com/microsoft/Swin-Transformer/blob/main/config.py
# We adjust the code style to match other recipes in icefall.


import argparse
import logging
import os
from pathlib import Path
from typing import Optional

from icefall.dist import get_rank, get_world_size
from icefall.utils import str2bool
from timm.data import Mixup, create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        if method == "bicubic":
            return InterpolationMode.BICUBIC
        elif method == "lanczos":
            return InterpolationMode.LANCZOS
        elif method == "hamming":
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR

    import timm.data.transforms as timm_transforms

    timm_transforms._pil_interp = _pil_interp
except:  # noqa
    from timm.data.transforms import _pil_interp
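
# Note: the try-block above monkey-patches timm so that, on newer torchvision
# versions, interpolation names resolve to torchvision's InterpolationMode
# enums rather than the deprecated PIL constants. A minimal sketch of the
# resulting behavior (assuming torchvision >= 0.9 is installed):
#
#     assert _pil_interp("bicubic") == InterpolationMode.BICUBIC
#     assert _pil_interp("unknown") == InterpolationMode.BILINEAR  # fallback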


class ImageNetClsDataModule:
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.rank = get_rank()
        self.world_size = get_world_size()

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="Image classification data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders -- they control the effective batch sizes, "
            "sampling strategies, applied data augmentations, etc.",
        )

        group.add_argument(
            "--data-path",
            type=Path,
            default=Path("imagenet"),
            help="Path to the ImageNet dataset.",
        )

        group.add_argument(
            "--batch-size",
            type=int,
            default=128,
            help="Batch size for a single GPU; can be overridden on the command line.",
        )

        group.add_argument(
            "--color-jitter",
            type=float,
            default=0.4,
            help="Color jitter factor",
        )

        group.add_argument(
            "--auto-augment",
            type=str,
            default="rand-m9-mstd0.5-inc1",
            help="AutoAugment policy, e.g., 'v0' or 'original'",
        )

        group.add_argument(
            "--reprob",
            type=float,
            default=0.25,
            help="Random erase probability",
        )

        group.add_argument(
            "--remode",
            type=str,
            default="pixel",
            help="Random erase mode",
        )

        group.add_argument(
            "--recount",
            type=int,
            default=1,
            help="Random erase count",
        )

        group.add_argument(
            "--interpolation",
            type=str,
            default="bicubic",
            help="Interpolation used to resize images (random, bilinear, bicubic)",
        )

        group.add_argument(
            "--crop",
            type=str2bool,
            default=True,
            help="Whether to use center crop when testing",
        )

        group.add_argument(
            "--mixup",
            type=float,
            default=0.8,
            help="Mixup alpha; mixup is enabled if > 0",
        )

        group.add_argument(
            "--cutmix",
            type=float,
            default=1.0,
            help="Cutmix alpha; cutmix is enabled if > 0",
        )

        group.add_argument(
            "--cutmix-minmax",
            type=float,
            default=None,
            help="Cutmix min/max ratio; overrides alpha and enables cutmix if set",
        )

        group.add_argument(
            "--mixup-prob",
            type=float,
            default=1.0,
            help="Probability of performing mixup or cutmix when either/both is enabled",
        )

        group.add_argument(
            "--mixup-switch-prob",
            type=float,
            default=0.5,
            help="Probability of switching to cutmix when both mixup and cutmix are enabled",
        )

        group.add_argument(
            "--mixup-mode",
            type=str,
            default="batch",
            help="How to apply mixup/cutmix params; per 'batch', 'pair', or 'elem'",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="Number of data loading worker processes",
        )

        group.add_argument(
            "--pin-memory",
            type=str2bool,
            default=True,
            help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
        )

    def build_transform(self, is_training: bool = False):
        resize_im = self.args.img_size > 32
        if is_training:
            # this should always dispatch to transforms_imagenet_train
            transform = create_transform(
                input_size=self.args.img_size,
                is_training=True,
                color_jitter=self.args.color_jitter
                if self.args.color_jitter > 0
                else None,
                auto_augment=self.args.auto_augment
                if self.args.auto_augment != "none"
                else None,
                re_prob=self.args.reprob,
                re_mode=self.args.remode,
                re_count=self.args.recount,
                interpolation=self.args.interpolation,
            )
            if not resize_im:
                # replace RandomResizedCropAndInterpolation with RandomCrop
                transform.transforms[0] = transforms.RandomCrop(
                    self.args.img_size, padding=4
                )
            return transform

        t = []
        if resize_im:
            if self.args.crop:
                size = int((256 / 224) * self.args.img_size)
                t.append(
                    transforms.Resize(
                        size, interpolation=_pil_interp(self.args.interpolation)
                    ),
                    # to maintain the same ratio w.r.t. 224 images
                )
                t.append(transforms.CenterCrop(self.args.img_size))
            else:
                t.append(
                    transforms.Resize(
                        (self.args.img_size, self.args.img_size),
                        interpolation=_pil_interp(self.args.interpolation),
                    )
                )

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
        return transforms.Compose(t)
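
    # For reference (an illustrative aside, not executed by the recipe): with
    # the defaults above and img_size=224, the eval-time pipeline built by
    # build_transform(False) is equivalent to
    #
    #     transforms.Compose([
    #         transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
    #         transforms.CenterCrop(224),
    #         transforms.ToTensor(),
    #         transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    #     ])
    #
    # since int((256 / 224) * 224) == 256.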

    def build_dataset(self, is_training: bool = False):
        transform = self.build_transform(is_training)
        prefix = "train" if is_training else "val"
        root = os.path.join(self.args.data_path, prefix)
        dataset = datasets.ImageFolder(root, transform=transform)
        return dataset

    def build_train_loader(
        self, num_classes: int, label_smoothing: Optional[float] = None
    ):
        assert num_classes == 1000, num_classes
        dataset_train = self.build_dataset(is_training=True)
        logging.info(f"rank {self.rank} successfully built the train dataset")

        if self.world_size > 1:
            sampler_train = DistributedSampler(
                dataset_train,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=True,
            )
        else:
            sampler_train = RandomSampler(dataset_train)

        # TODO: need to set up worker_init_fn?
        data_loader_train = DataLoader(
            dataset_train,
            sampler=sampler_train,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            drop_last=True,
        )

        # setup mixup / cutmix
        mixup_fn = None
        mixup_active = (
            self.args.mixup > 0
            or self.args.cutmix > 0.0
            or self.args.cutmix_minmax is not None
        )
        if mixup_active:
            mixup_fn = Mixup(
                mixup_alpha=self.args.mixup,
                cutmix_alpha=self.args.cutmix,
                cutmix_minmax=self.args.cutmix_minmax,
                prob=self.args.mixup_prob,
                switch_prob=self.args.mixup_switch_prob,
                mode=self.args.mixup_mode,
                label_smoothing=label_smoothing,
                num_classes=num_classes,
            )

        return data_loader_train, mixup_fn
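
    # Usage note (a sketch, not part of the recipe): mixup_fn is a
    # timm.data.Mixup callable. Applied to a batch, it mixes the images and
    # converts integer labels to soft targets, e.g.
    #
    #     images, targets = mixup_fn(images, targets)
    #     # images: (B, 3, H, W) mixed; targets: (B, num_classes) soft labels
    #
    # which is why train.py pairs it with SoftTargetCrossEntropy.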

    def build_val_loader(self):
        dataset_val = self.build_dataset(is_training=False)
        logging.info(f"rank {self.rank} successfully built the val dataset")

        if self.world_size > 1:
            sampler_val = DistributedSampler(
                dataset_val, num_replicas=self.world_size, rank=self.rank, shuffle=False
            )
        else:
            sampler_val = SequentialSampler(dataset_val)

        data_loader_val = DataLoader(
            dataset_val,
            sampler=sampler_val,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            drop_last=False,
        )

        return data_loader_val
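
The datamodule above is meant to be driven from train.py. As a minimal sketch of how it wires together (assuming an ImageNet directory with train/ and val/ subfolders exists; the argument names come from add_arguments above):

    import argparse

    parser = argparse.ArgumentParser()
    ImageNetClsDataModule.add_arguments(parser)
    args = parser.parse_args(["--data-path", "imagenet", "--batch-size", "128"])
    args.img_size = 224  # normally set by train.py via get_params(), not this parser

    datamodule = ImageNetClsDataModule(args)
    train_dl, mixup_fn = datamodule.build_train_loader(
        num_classes=1000, label_smoothing=0.1
    )
    val_dl = datamodule.build_val_loader()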
1  egs/imagenet/CLS/swin_transformer/optim.py  Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py
617  egs/imagenet/CLS/swin_transformer/swin_transformer.py  Normal file
@ -0,0 +1,617 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------


# copied from https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

try:
    import os, sys

    kernel_path = os.path.abspath(os.path.join('..'))
    sys.path.append(kernel_path)
    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse

except:  # noqa
    WindowProcess = None
    WindowProcessReverse = None
    print("[Warning] Fused window process has not been installed. Please refer to get_started.md for installation.")


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
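
# Sanity check (an illustrative sketch, not executed by the recipe):
# window_partition and window_reverse are exact inverses whenever H and W are
# multiples of window_size:
#
#     x = torch.randn(2, 56, 56, 96)   # (B, H, W, C)
#     w = window_partition(x, 7)       # (2 * 8 * 8, 7, 7, 96)
#     assert torch.equal(window_reverse(w, 7, 56, 56), x)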


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the window size is larger than the input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
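
# Shape note (an illustrative sketch, not executed by the recipe): with
# input_resolution=(56, 56) and dim=96, PatchMerging maps
#
#     (B, 56*56, 96) -> gather 2x2 neighbors -> (B, 28*28, 384)
#                    -> LayerNorm + Linear   -> (B, 28*28, 192)
#
# i.e., the spatial resolution halves while the channel dim doubles per stage.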


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 fused_window_process=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 fused_window_process=fused_window_process)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, fused_window_process=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               fused_window_process=fused_window_process)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
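
For orientation (a sketch under the defaults above, not part of the diff): the constructor defaults reproduce Swin-T, and a forward pass looks like this.

    import torch
    from swin_transformer import SwinTransformer

    model = SwinTransformer()  # Swin-T: embed_dim=96, depths=(2, 2, 6, 2)
    images = torch.randn(2, 3, 224, 224)
    logits = model(images)
    assert logits.shape == (2, 1000)
    print(f"GFLOPs: {model.flops() / 1e9:.2f}")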
924  egs/imagenet/CLS/swin_transformer/train.py  Executable file
@ -0,0 +1,924 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Wei Kang,
#                                        Mingshuang Luo,
#                                        Zengwei Yao,
#                                        Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import copy
import datetime
import logging
import time
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union

import optim
import torch
import torch.multiprocessing as mp
from cls_datamodule import ImageNetClsDataModule
from optim import Eden, ScaledAdam
from utils import AverageMeter, accuracy, fix_random_seed, reduce_tensor
from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from torch import nn
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    setup_logger,
    str2bool,
    get_parameter_groups_with_lrs,
)
from swin_transformer import SwinTransformer


LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def get_adjusted_batch_count(params: AttributeDict) -> float:
    # Returns the number of batches we would have used so far.
    # This is for purposes of set_batch_count().
    return params.batch_idx_train * params.world_size


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for name, module in model.named_modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
        if hasattr(module, "name"):
            module.name = name


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--patch-size",
        type=int,
        default=4,
        help="Patch size. Default: 4",
    )

    parser.add_argument(
        "--embed-dim",
        type=int,
        default=96,
        help="Patch embedding dimension. Default: 96",
    )

    parser.add_argument(
        "--depths",
        type=str,
        default="2,2,6,2",
        help="Depth of each Swin Transformer layer.",
    )

    parser.add_argument(
        "--num-heads",
        type=str,
        default="3,6,12,24",
        help="Number of attention heads in different layers.",
    )

    parser.add_argument(
        "--window-size",
        type=int,
        default=7,
        help="Window size. Default: 7",
    )

    parser.add_argument(
        "--mlp-ratio",
        type=float,
        default=4.0,
        help="Ratio of mlp hidden dim to embedding dim. Default: 4",
    )

    parser.add_argument(
        "--qkv-bias",
        type=str2bool,
        default=True,
        help="If True, add a learnable bias to query, key, value. Default: True",
    )

    parser.add_argument(
        "--qk-scale",
        type=float,
        default=None,
        help="Override default qk scale of head_dim ** -0.5 if set. Default: None",
    )

    parser.add_argument(
        "--ape",
        type=str2bool,
        default=False,
        help="If True, add absolute position embedding to the patch embedding. Default: False",
    )

    parser.add_argument(
        "--patch-norm",
        type=str2bool,
        default=True,
        help="If True, add normalization after patch embedding. Default: True",
    )

    parser.add_argument(
        "--drop-rate",
        type=float,
        default=0.0,
        help="Dropout rate",
    )

    parser.add_argument(
        "--drop-path-rate",
        type=float,
        default=0.1,
        help="Drop path rate",
    )

    parser.add_argument(
        "--fused-window-process",
        type=str2bool,
        default=False,
        help="If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False",
    )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Whether various information should be logged to tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=30,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load the checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="swin_transformer/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--base-lr", type=float, default=0.025, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=7500,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=3.5,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators; intended for reproducibility",
    )

    parser.add_argument(
        "--label-smoothing",
        type=float,
        default=0.1,
        help="Label smoothing used in loss computation",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
    )
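
    # For intuition (an illustrative aside, not part of the recipe): with
    # average_period=200 at batch_idx_train=1000, the update above reads
    #     model_avg = model * 0.2 + model_avg * 0.8
    # so the latest 200 batches contribute 200/1000 of the result, keeping
    # model_avg an equal-weight average over all batches seen so far.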

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    add_model_arguments(parser)

    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during the training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during the training.

        - best_train_epoch: It is the epoch that has the best training loss.

        - best_valid_epoch: It is the epoch that has the best validation loss.

        - batch_idx_train: Used for writing statistics to tensorboard. It
                           contains the number of batches trained so far across
                           epochs.

        - log_interval: Print training loss if batch_idx % log_interval is 0

        - reset_interval: Reset statistics if batch_idx % reset_interval is 0

        - valid_interval: Run validation if batch_idx % valid_interval is 0
    """
    params = AttributeDict(
        {
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_accuracy": 0.0,  # acc1
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 3000,
            "valid_log_interval": 10,
            # parameters for SwinTransformer
            "img_size": 224,
            "in_chans": 3,
            "num_classes": 1000,
            "env_info": get_env_info(),
        }
    )

    return params


def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: nn.Module = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_epoch is larger than 1, it will load the checkpoint from
    `params.start_epoch - 1`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, `best_valid_loss`,
    and `best_accuracy` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        model_avg=model_avg,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
        "best_accuracy",
    ]
    for k in keys:
        params[k] = saved_params[k]

    return saved_params


def save_checkpoint(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer used in the training.
      scaler:
        The scaler used for mixed precision training.
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        rank=rank,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


@torch.no_grad()
def validate(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    tb_writer: Optional[SummaryWriter] = None,
) -> None:
    """Run the validation process."""
    model.eval()

    criterion = torch.nn.CrossEntropyLoss()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for batch_idx, (images, targets) in enumerate(valid_dl):
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        # compute outputs
        outputs = model(images)

        # measure accuracy and record loss
        loss = criterion(outputs, targets)
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))

        if world_size > 1:
            acc1 = reduce_tensor(acc1)
            acc5 = reduce_tensor(acc5)
            loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), targets.size(0))
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % params.valid_log_interval == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logging.info(
                f"Test: [{batch_idx}/{len(valid_dl)}]\t"
                f"Time {batch_time}\t"
                f"Loss {loss_meter}\t"
                f"Acc@1 {acc1_meter}\t"
                f"Acc@5 {acc5_meter}\t"
                f"Mem {memory_used:.0f}MB"
            )

    logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")

    if tb_writer is not None:
        tb_writer.add_scalar("train/valid_loss", loss_meter.avg, params.batch_idx_train)
        tb_writer.add_scalar("train/valid_acc1", acc1_meter.avg, params.batch_idx_train)
        tb_writer.add_scalar("train/valid_acc5", acc5_meter.avg, params.batch_idx_train)

    if loss_meter.avg < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_meter.avg

    if acc1_meter.avg > params.best_accuracy:
        params.best_accuracy = acc1_meter.avg
    logging.info(f"Best accuracy: {params.best_accuracy:.2f}%")


def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    mixup_fn: Optional[Mixup] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The average training loss over the epoch is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler; we call step() for every step.
      train_dl:
        Dataloader for the training dataset.
      scaler:
        The scaler used for mixed precision training.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
model.train()
|
||||
|
||||
if params.mixup > 0.0:
|
||||
# smoothing is handled with mixup label transform
|
||||
criterion = SoftTargetCrossEntropy()
|
||||
elif params.label_smoothing > 0.0:
|
||||
criterion = LabelSmoothingCrossEntropy(smoothing=params.label_smoothing)
|
||||
else:
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
saved_bad_model = False
|
||||
|
||||
def save_bad_model(suffix: str = ""):
|
||||
save_checkpoint_impl(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
scaler=scaler,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
batch_time = AverageMeter()
|
||||
loss_meter = AverageMeter()
|
||||
|
||||
num_steps = len(train_dl)
|
||||
|
||||
start = time.time()
|
||||
end = time.time()
|
||||
for batch_idx, (images, targets) in enumerate(train_dl):
|
||||
if batch_idx % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
|
||||
params.batch_idx_train += 1
|
||||
|
||||
images = images.cuda(non_blocking=True)
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mixup_fn is not None:
|
||||
images, targets = mixup_fn(images, targets)
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
# compute outputs
|
||||
outputs = model(images)
|
||||
# measure accuracy and record loss
|
||||
loss = criterion(outputs, targets)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# summary stats
|
||||
loss_meter.update(loss.item(), targets.size(0))
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
except: # noqa
|
||||
save_bad_model()
|
||||
raise
|
||||
|
||||
        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            update_averaged_model(params=params, model_cur=model, model_avg=model_avg)
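            # update_averaged_model maintains a running average of the model
            # parameters (kept in float64 on rank 0); validate.py can later
            # combine two such averaged checkpoints to recover the model
            # averaged over an epoch range.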

        if batch_idx % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    save_bad_model()
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )

        if batch_idx % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}/{num_steps}, "
                f"time {batch_time}, "
                f"loss {loss_meter}, "
                f"batch size {targets.size(0)}, "
                f"lr: {cur_lr:.2e}, "
                f"mem {memory_used:.0f}MB, "
                + (f"grad_scale: {cur_grad_scale}" if params.use_fp16 else "")
            )
            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/current_loss", loss_meter.val, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/averaged_loss", loss_meter.avg, params.batch_idx_train
                )

                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )

    epoch_time = time.time() - start
    logging.info(
        f"Epoch {params.cur_epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )

    if loss_meter.avg < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = loss_meter.avg


def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))
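# Example: _to_int_tuple("2,2,6,2") == (2, 2, 6, 2). Used below to parse the
# comma-separated params.depths and params.num_heads strings.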


def get_model(params):
    model = SwinTransformer(
        img_size=params.img_size,
        patch_size=params.patch_size,
        in_chans=params.in_chans,
        num_classes=params.num_classes,
        embed_dim=params.embed_dim,
        depths=_to_int_tuple(params.depths),
        num_heads=_to_int_tuple(params.num_heads),
        window_size=params.window_size,
        mlp_ratio=params.mlp_ratio,
        qkv_bias=params.qkv_bias,
        qk_scale=params.qk_scale,
        drop_rate=params.drop_rate,
        drop_path_rate=params.drop_path_rate,
        ape=params.ape,
        patch_norm=params.patch_norm,
        fused_window_process=params.fused_window_process,
    )
    return model


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size - 1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs used for DDP training.
      args:
        The return value of get_parser().parse_args().
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed, rank)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is currently unavailable.")
    device = torch.device("cuda", rank)

    logging.info(f"Device: {device}")

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr,  # should have no effect
        clipping_scale=2.0,
    )

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
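    # Eden decays the learning rate in both batch and epoch, roughly
    #   lr = base_lr
    #        * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    #        * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25,
    # so step_batch() must be called every batch and step_epoch() every epoch.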

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    # Create datasets and dataloaders
    imagenet = ImageNetClsDataModule(params)
    train_dl, mixup_fn = imagenet.build_train_loader(
        num_classes=params.num_classes, label_smoothing=params.label_smoothing
    )
    valid_dl = imagenet.build_val_loader()

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
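    # init_scale=1.0 starts the loss scale low instead of PyTorch's default
    # of 65536; the block in train_one_epoch that doubles small scales then
    # grows it gradually, which tends to be more stable early in training.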
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
||||
    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1, rank)
        if world_size > 1:
            # For DistributedSampler
            train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            mixup_fn=mixup_fn,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        validate(
            params=params,
            model=model,
            valid_dl=valid_dl,
            world_size=world_size,
            tb_writer=tb_writer,
        )

        save_checkpoint(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            rank=rank,
        )

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    parser = get_parser()
    ImageNetClsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


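# Limit PyTorch's intra-/inter-op thread pools to one thread each, so that
# DDP workers and DataLoader worker processes do not oversubscribe the CPU.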
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
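
# For reference, a typical training invocation might look like the following
# (a sketch: flag names are inferred from the parameters used above, e.g.
# args.world_size -> --world-size, and the paths are placeholders):
#
#   ./swin_transformer/train.py \
#     --world-size 4 --num-epochs 30 --use-fp16 1 \
#     --exp-dir swin_transformer/exp --data-path /path/to/imagenet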
78
egs/imagenet/CLS/swin_transformer/utils.py
Normal file
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
import numpy as np

import torch
import torch.distributed as dist


# We might need to move this file to icefall/utils.py in the future

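# reduce_tensor averages a metric tensor across DDP ranks: all_reduce with
# SUM followed by division by the world size leaves every rank holding the
# mean value (intended for aggregating metrics across GPUs).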
# Copied from https://github.com/microsoft/Swin-Transformer/blob/main/utils.py
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def fix_random_seed(random_seed: int, rank: int):
    seed = random_seed + rank
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self) -> str:
        return f"{self.val:.4f} (avg: {self.avg:.4f})"
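# Example (weighted running mean, with n as the batch size):
#   m = AverageMeter()
#   m.update(0.5, n=128)   # m.avg == 0.5
#   m.update(1.0, n=128)   # m.avg == 0.75, m.val == 1.0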


# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [
        correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size
        for k in topk
    ]
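

# A minimal, illustrative sanity check of accuracy(), runnable via
# `python utils.py`:
if __name__ == "__main__":
    # sample 0 is classified correctly at top-1; sample 1 only within top-2
    logits = torch.tensor([[2.0, 1.0, 0.0], [0.5, 0.2, 1.0]])
    targets = torch.tensor([0, 0])
    top1, top2 = accuracy(logits, targets, topk=(1, 2))
    print(top1.item(), top2.item())  # expected: 50.0 100.0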
201
egs/imagenet/CLS/swin_transformer/validate.py
Executable file
@@ -0,0 +1,201 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import time
from pathlib import Path

import torch
import torch.nn as nn

from cls_datamodule import ImageNetClsDataModule
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    load_checkpoint,
)
from icefall.utils import (
    AttributeDict,
    setup_logger,
    str2bool,
)
from utils import AverageMeter, accuracy


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for validation.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would validate with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch numbers `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="swin_transformer/exp",
        help="The experiment dir",
    )

    add_model_arguments(parser)

    return parser


def validate(
    params: AttributeDict,
    model: nn.Module,
    valid_dl: torch.utils.data.DataLoader,
) -> None:
    """Run the validation process."""
    batch_time = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for batch_idx, (images, targets) in enumerate(valid_dl):
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        # compute outputs
        outputs = model(images)

        # measure top-1/top-5 accuracy
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))

        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")


@torch.no_grad()
def main():
    parser = get_parser()
    ImageNetClsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.exp_dir}/log-decode-{params.suffix}")
    logging.info("Validation started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # Create datasets and dataloaders
    imagenet = ImageNetClsDataModule(params)
    valid_dl = imagenet.build_val_loader()

    validate(
        params=params,
        model=model,
        valid_dl=valid_dl,
    )

    logging.info("Done!")


if __name__ == "__main__":
    main()
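
# For reference, a typical invocation might look like the following (a
# sketch: the exact flags come from get_parser() and
# ImageNetClsDataModule.add_arguments(), and the paths are placeholders):
#
#   ./swin_transformer/validate.py \
#     --epoch 30 --avg 15 --use-averaged-model 1 \
#     --exp-dir swin_transformer/exp --data-path /path/to/imagenet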