diff --git a/egs/imagenet/CLS/swin_transformer/cls_datamodule.py b/egs/imagenet/CLS/swin_transformer/cls_datamodule.py
new file mode 100644
index 000000000..8628cdd09
--- /dev/null
+++ b/egs/imagenet/CLS/swin_transformer/cls_datamodule.py
@@ -0,0 +1,317 @@
+# Copyright      2023  Xiaomi Corp.        (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This is modified from https://github.com/microsoft/Swin-Transformer/blob/main/data/build.py
+# The default args are copied from https://github.com/microsoft/Swin-Transformer/blob/main/config.py
+# We adjust the code style to match other recipes in icefall.
+
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+from icefall.dist import get_rank, get_world_size
+from icefall.utils import str2bool
+from timm.data import Mixup, create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+from torchvision import datasets, transforms
+
+try:
+    from torchvision.transforms import InterpolationMode
+
+    def _pil_interp(method):
+        if method == "bicubic":
+            return InterpolationMode.BICUBIC
+        elif method == "lanczos":
+            return InterpolationMode.LANCZOS
+        elif method == "hamming":
+            return InterpolationMode.HAMMING
+        else:
+            # default bilinear, do we want to allow nearest?
+            return InterpolationMode.BILINEAR
+
+    import timm.data.transforms as timm_transforms
+
+    timm_transforms._pil_interp = _pil_interp
+except ImportError:  # fall back for older torchvision/timm versions
+    from timm.data.transforms import _pil_interp
+
+
+class ImageNetClsDataModule:
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+        self.rank = get_rank()
+        self.world_size = get_world_size()
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="Image classification data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders -- they control the effective batch sizes, "
+            "sampling strategies, applied data augmentations, etc.",
+        )
+
+        group.add_argument(
+            "--data-path",
+            type=Path,
+            default=Path("imagenet"),
+            help="Path to the ImageNet dataset.",
+        )
+
+        group.add_argument(
+            "--batch-size",
+            type=int,
+            default=128,
+            help="Batch size for a single GPU.",
+        )
+
+        group.add_argument(
+            "--color-jitter",
+            type=float,
+            default=0.4,
+            help="Color jitter factor",
+        )
+
+        group.add_argument(
+            "--auto-augment",
+            type=str,
+            default="rand-m9-mstd0.5-inc1",
+            help="AutoAugment policy, e.g., 'v0' or 'original'",
+        )
+
+        group.add_argument(
+            "--reprob",
+            type=float,
+            default=0.25,
+            help="Random erase prob",
+        )
+
+        group.add_argument(
+            "--remode",
+            type=str,
+            default="pixel",
+            help="Random erase mode",
+        )
+
+        group.add_argument(
+            "--recount",
+            type=int,
+            default=1,
+            help="Random erase count",
+        )
+
+        group.add_argument(
+            "--interpolation",
+            type=str,
+            default="bicubic",
+            help="Interpolation to resize image (random, bilinear, bicubic)",
+        )
+
+        group.add_argument(
+            "--crop",
+            type=str2bool,
+            default=True,
+            help="Whether to use center crop when testing",
+        )
+
+        group.add_argument(
+            "--mixup",
+            type=float,
+            default=0.8,
+            help="Mixup alpha, mixup enabled if > 0",
+        )
+
+        group.add_argument(
+            "--cutmix",
+            type=float,
+            default=1.0,
+            help="Cutmix alpha, cutmix enabled if > 0",
+        )
+
+        group.add_argument(
+            "--cutmix-minmax",
+            type=float,
+            default=None,
+            help="Cutmix min/max ratio, overrides alpha and enables cutmix if set",
+        )
+
+        group.add_argument(
+            "--mixup-prob",
+            type=float,
+            default=1.0,
+            help="Probability of performing mixup or cutmix when either/both is enabled",
+        )
+
+        group.add_argument(
+            "--mixup-switch-prob",
+            type=float,
+            default=0.5,
+            help="Probability of switching to cutmix when both mixup and cutmix enabled",
+        )
+
+        group.add_argument(
+            "--mixup-mode",
+            type=str,
+            default="batch",
+            help="How to apply mixup/cutmix params. Per 'batch', 'pair', or 'elem'",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=8,
+            help="Number of data loading workers",
+        )
+
+        group.add_argument(
+            "--pin-memory",
+            type=str2bool,
+            default=True,
+            help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
+        )
+
+    def build_transform(self, is_training: bool = False):
+        resize_im = self.args.img_size > 32
+        if is_training:
+            # this should always dispatch to transforms_imagenet_train
+            transform = create_transform(
+                input_size=self.args.img_size,
+                is_training=True,
+                color_jitter=self.args.color_jitter
+                if self.args.color_jitter > 0
+                else None,
+                auto_augment=self.args.auto_augment
+                if self.args.auto_augment != "none"
+                else None,
+                re_prob=self.args.reprob,
+                re_mode=self.args.remode,
+                re_count=self.args.recount,
+                interpolation=self.args.interpolation,
+            )
+            if not resize_im:
+                # replace RandomResizedCropAndInterpolation with
+                # RandomCrop
+                transform.transforms[0] = transforms.RandomCrop(
+                    self.args.img_size, padding=4
+                )
+            return transform
+
+        t = []
+        if resize_im:
+            if self.args.crop:
+                # to maintain the same ratio w.r.t. 224 images
+                size = int((256 / 224) * self.args.img_size)
+                t.append(
+                    transforms.Resize(
+                        size, interpolation=_pil_interp(self.args.interpolation)
+                    )
+                )
+                t.append(transforms.CenterCrop(self.args.img_size))
+            else:
+                t.append(
+                    transforms.Resize(
+                        (self.args.img_size, self.args.img_size),
+                        interpolation=_pil_interp(self.args.interpolation),
+                    )
+                )
+
+        t.append(transforms.ToTensor())
+        t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+        return transforms.Compose(t)
+
+    def build_dataset(self, is_training: bool = False):
+        transform = self.build_transform(is_training)
+        prefix = "train" if is_training else "val"
+        root = os.path.join(self.args.data_path, prefix)
+        dataset = datasets.ImageFolder(root, transform=transform)
+        return dataset
+
+    def build_train_loader(
+        self, num_classes: int, label_smoothing: Optional[float] = None
+    ):
+        assert num_classes == 1000, num_classes
+        dataset_train = self.build_dataset(is_training=True)
+        logging.info(f"rank {self.rank} successfully built the train dataset")
+
+        if self.world_size > 1:
+            sampler_train = DistributedSampler(
+                dataset_train,
+                num_replicas=self.world_size,
+                rank=self.rank,
+                shuffle=True,
+            )
+        else:
+            sampler_train = RandomSampler(dataset_train)
+
+        # TODO: need to set up worker_init_fn?
+        data_loader_train = DataLoader(
+            dataset_train,
+            sampler=sampler_train,
+            batch_size=self.args.batch_size,
+            num_workers=self.args.num_workers,
+            pin_memory=self.args.pin_memory,
+            drop_last=True,
+        )
+
+        # setup mixup / cutmix
+        mixup_fn = None
+        mixup_active = (
+            self.args.mixup > 0
+            or self.args.cutmix > 0.0
+            or self.args.cutmix_minmax is not None
+        )
+        if mixup_active:
+            mixup_fn = Mixup(
+                mixup_alpha=self.args.mixup,
+                cutmix_alpha=self.args.cutmix,
+                cutmix_minmax=self.args.cutmix_minmax,
+                prob=self.args.mixup_prob,
+                switch_prob=self.args.mixup_switch_prob,
+                mode=self.args.mixup_mode,
+                label_smoothing=label_smoothing,
+                num_classes=num_classes,
+            )
+
+        return data_loader_train, mixup_fn
+
+    def build_val_loader(self):
+        dataset_val = self.build_dataset(is_training=False)
+        logging.info(f"rank {self.rank} successfully built the val dataset")
+
+        if self.world_size > 1:
+            sampler_val = DistributedSampler(
+                dataset_val, num_replicas=self.world_size, rank=self.rank, shuffle=False
+            )
+        else:
+            sampler_val = SequentialSampler(dataset_val)
+
+        data_loader_val = DataLoader(
+            dataset_val,
+            sampler=sampler_val,
+            batch_size=self.args.batch_size,
+            shuffle=False,
+            num_workers=self.args.num_workers,
+            pin_memory=self.args.pin_memory,
+            drop_last=False,
+        )
+
+        return data_loader_val
diff --git a/egs/imagenet/CLS/swin_transformer/optim.py b/egs/imagenet/CLS/swin_transformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/imagenet/CLS/swin_transformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/imagenet/CLS/swin_transformer/swin_transformer.py b/egs/imagenet/CLS/swin_transformer/swin_transformer.py
new file mode 100644
index 000000000..19d958a62
--- /dev/null
+++ b/egs/imagenet/CLS/swin_transformer/swin_transformer.py
@@ -0,0 +1,617 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+
+# copied from https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+try:
+    import os
+    import sys
+
+    kernel_path = os.path.abspath(os.path.join('..'))
+    sys.path.append(kernel_path)
+    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
+
+except ImportError:
+    WindowProcess = None
+    WindowProcessReverse = None
+    print("[Warning] Fused window process has not been installed. Please refer to get_started.md for installation.")
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        # x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 fused_window_process=False):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+            h_slices = (slice(0, -self.window_size),
+                        slice(-self.window_size, -self.shift_size),
+                        slice(-self.shift_size, None))
+            w_slices = (slice(0, -self.window_size),
+                        slice(-self.window_size, -self.shift_size),
+                        slice(-self.shift_size, None))
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+        self.fused_window_process = fused_window_process
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            if not self.fused_window_process:
+                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+                # partition windows
+                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+            else:
+                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
+        else:
+            shifted_x = x
+            # partition windows
+            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            if not self.fused_window_process:
+                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+            else:
+                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
+        else:
+            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+            x = shifted_x
+        x = x.view(B, H * W, C)
+        x = shortcut + self.drop_path(x)
+
+        # FFN
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self):
+        flops = 0
+        H, W = self.input_resolution
+        # norm1
+        flops += self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
+
+
+class PatchMerging(nn.Module):
+    r""" Patch Merging Layer.
+
+    Args:
+        input_resolution (tuple[int]): Resolution of input feature.
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+        x = x.view(B, H, W, C)
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+    def flops(self):
+        H, W = self.input_resolution
+        flops = H * W * self.dim
+        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+        return flops
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False
+    """
+
+    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+                 fused_window_process=False):
+
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+                                 num_heads=num_heads, window_size=window_size,
+                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
+                                 mlp_ratio=mlp_ratio,
+                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                 drop=drop, attn_drop=attn_drop,
+                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                 norm_layer=norm_layer,
+                                 fused_window_process=fused_window_process)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        if self.downsample is not None:
+            x = self.downsample(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+    def flops(self):
+        flops = 0
+        for blk in self.blocks:
+            flops += blk.flops()
+        if self.downsample is not None:
+            flops += self.downsample.flops()
+        return flops
+
+
+class PatchEmbed(nn.Module):
+    r""" Image to Patch Embedding
+
+    Args:
+        img_size (int): Image size. Default: 224.
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patches_resolution = patches_resolution
+        self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
+        if self.norm is not None:
+            x = self.norm(x)
+        return x
+
+    def flops(self):
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
+
+
+class SwinTransformer(nn.Module):
+    r""" Swin Transformer
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+
+    Args:
+        img_size (int | tuple(int)): Input image size. Default 224
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        embed_dim (int): Patch embedding dimension. Default: 96
+        depths (tuple(int)): Depth of each Swin Transformer layer.
+        num_heads (tuple(int)): Number of attention heads in different layers.
+        window_size (int): Window size. Default: 7
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                 use_checkpoint=False, fused_window_process=False, **kwargs):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+        self.mlp_ratio = mlp_ratio
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+        num_patches = self.patch_embed.num_patches
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        # absolute position embedding
+        if self.ape:
+            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                 patches_resolution[1] // (2 ** i_layer)),
+                               depth=depths[i_layer],
+                               num_heads=num_heads[i_layer],
+                               window_size=window_size,
+                               mlp_ratio=self.mlp_ratio,
+                               qkv_bias=qkv_bias, qk_scale=qk_scale,
+                               drop=drop_rate, attn_drop=attn_drop_rate,
+                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                               norm_layer=norm_layer,
+                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                               use_checkpoint=use_checkpoint,
+                               fused_window_process=fused_window_process)
+            self.layers.append(layer)
+
+        self.norm = norm_layer(self.num_features)
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'absolute_pos_embed'}
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {'relative_position_bias_table'}
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x)
+
+        x = self.norm(x)  # B L C
+        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x = torch.flatten(x, 1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+    def flops(self):
+        flops = 0
+        flops += self.patch_embed.flops()
+        for i, layer in enumerate(self.layers):
+            flops += layer.flops()
+        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+        flops += self.num_features * self.num_classes
+        return flops
diff --git a/egs/imagenet/CLS/swin_transformer/train.py b/egs/imagenet/CLS/swin_transformer/train.py
new file mode 100755
index 000000000..14f760b6c
--- /dev/null
+++ b/egs/imagenet/CLS/swin_transformer/train.py
@@ -0,0 +1,924 @@
+#!/usr/bin/env python3
+# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                  Wei Kang,
+#                                                  Mingshuang Luo,
+#                                                  Zengwei Yao,
+#                                                  Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
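+
+"""
+Usage example (the values below are illustrative; adjust them to your setup):
+
+  ./swin_transformer/train.py \
+    --world-size 4 \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --use-fp16 True \
+    --exp-dir swin_transformer/exp \
+    --data-path /path/to/imagenet
+"""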
+
+
+import argparse
+import copy
+import datetime
+import logging
+import time
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Union
+
+import optim
+import torch
+import torch.multiprocessing as mp
+from cls_datamodule import ImageNetClsDataModule
+from optim import Eden, ScaledAdam
+from timm.data import Mixup
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from torch import nn
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from utils import AverageMeter, accuracy, fix_random_seed, reduce_tensor
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+    AttributeDict,
+    get_parameter_groups_with_lrs,
+    setup_logger,
+    str2bool,
+)
+from swin_transformer import SwinTransformer
+
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+    # Returns the number of batches we would have used so far.
+    # This is for purposes of set_batch_count().
+    return params.batch_idx_train * params.world_size
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for name, module in model.named_modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+        if hasattr(module, "name"):
+            module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--patch-size",
+        type=int,
+        default=4,
+        help="Patch size. Default: 4",
+    )
+
+    parser.add_argument(
+        "--embed-dim",
+        type=int,
+        default=96,
+        help="Patch embedding dimension. Default: 96",
+    )
+
+    parser.add_argument(
+        "--depths",
+        type=str,
+        default="2,2,6,2",
+        help="Depth of each Swin Transformer layer.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="3,6,12,24",
+        help="Number of attention heads in different layers.",
+    )
+
+    parser.add_argument(
+        "--window-size",
+        type=int,
+        default=7,
+        help="Window size. Default: 7",
+    )
+
+    parser.add_argument(
+        "--mlp-ratio",
+        type=float,
+        default=4.0,
+        help="Ratio of mlp hidden dim to embedding dim. Default: 4",
+    )
+
+    parser.add_argument(
+        "--qkv-bias",
+        type=str2bool,
+        default=True,
+        help="If True, add a learnable bias to query, key, value. Default: True",
+    )
+
+    parser.add_argument(
+        "--qk-scale",
+        type=float,
+        default=None,
+        help="Override default qk scale of head_dim ** -0.5 if set. Default: None",
+    )
+
+    parser.add_argument(
+        "--ape",
+        type=str2bool,
+        default=False,
+        help="If True, add absolute position embedding to the patch embedding. Default: False",
+    )
+
+    parser.add_argument(
+        "--patch-norm",
+        type=str2bool,
+        default=True,
+        help="If True, add normalization after patch embedding. Default: True",
+    )
+
+    parser.add_argument(
+        "--drop-rate",
+        type=float,
+        default=0.0,
+        help="Dropout rate",
+    )
+
+    parser.add_argument(
+        "--drop-path-rate",
+        type=float,
+        default=0.1,
+        help="Drop path rate",
+    )
+
+    parser.add_argument(
+        "--fused-window-process",
+        type=str2bool,
+        default=False,
+        help="If True, use one kernel to fuse window shift & window partition for acceleration; similar for the reversed part. Default: False",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="swin_transformer/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.025, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=7500,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--label-smoothing",
+        type=float,
+        default=0.1,
+        help="Label smoothing used in loss computation",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_accuracy": 0.0,  # acc1
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,
+            "valid_log_interval": 10,
+            # parameters for SwinTransformer
+            "img_size": 224,
+            "in_chans": 3,
+            "num_classes": 1000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, `best_valid_loss`,
+    and `best_accuracy` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
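+
+    # `load_checkpoint` restores the given state dicts in place; the returned
+    # dict also contains the saved training stats, which are merged into
+    # `params` below.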
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+        "best_accuracy",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      scheduler:
+        The learning rate scheduler used in the training.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+@torch.no_grad()
+def validate(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+    tb_writer: Optional[SummaryWriter] = None,
+) -> None:
+    """Run the validation process."""
+    model.eval()
+
+    criterion = torch.nn.CrossEntropyLoss()
+
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    acc1_meter = AverageMeter()
+    acc5_meter = AverageMeter()
+
+    end = time.time()
+    for batch_idx, (images, targets) in enumerate(valid_dl):
+        images = images.cuda(non_blocking=True)
+        targets = targets.cuda(non_blocking=True)
+
+        # compute outputs
+        outputs = model(images)
+
+        # measure accuracy and record loss
+        loss = criterion(outputs, targets)
+        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
+
+        if world_size > 1:
+            acc1 = reduce_tensor(acc1)
+            acc5 = reduce_tensor(acc5)
+            loss = reduce_tensor(loss)
+
+        loss_meter.update(loss.item(), targets.size(0))
+        acc1_meter.update(acc1.item(), targets.size(0))
+        acc5_meter.update(acc5.item(), targets.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if batch_idx % params.valid_log_interval == 0:
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logging.info(
+                f"Test: [{batch_idx}/{len(valid_dl)}]\t"
+                f"Time {batch_time}\t"
+                f"Loss {loss_meter}\t"
+                f"Acc@1 {acc1_meter}\t"
+                f"Acc@5 {acc5_meter}\t"
+                f"Mem {memory_used:.0f}MB"
+            )
+
+    logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")
+
+    if tb_writer is not None:
+        tb_writer.add_scalar("train/valid_loss", loss_meter.avg, params.batch_idx_train)
+        tb_writer.add_scalar("train/valid_acc1", acc1_meter.avg, params.batch_idx_train)
+        tb_writer.add_scalar("train/valid_acc5", acc5_meter.avg, params.batch_idx_train)
+
+    if loss_meter.avg < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_meter.avg
+
+    if acc1_meter.avg > params.best_accuracy:
+        params.best_accuracy = acc1_meter.avg
+    logging.info(f"Best accuracy: {params.best_accuracy:.2f}%")
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    train_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    mixup_fn: Optional[Mixup] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The average training loss over the epoch is tracked and used to update
+    `params.best_train_loss` at the end of the epoch.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      mixup_fn:
+        The function that applies mixup / cutmix to each batch, if enabled.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    if params.mixup > 0.0:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif params.label_smoothing > 0.0:
+        criterion = LabelSmoothingCrossEntropy(smoothing=params.label_smoothing)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    saved_bad_model = False
+
+    def save_bad_model(suffix: str = ""):
+        save_checkpoint_impl(
+            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+            model=model,
+            model_avg=model_avg,
+            params=params,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            scaler=scaler,
+            rank=0,
+        )
+
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+
+    num_steps = len(train_dl)
+
+    start = time.time()
+    end = time.time()
+    for batch_idx, (images, targets) in enumerate(train_dl):
+        if batch_idx % 10 == 0:
+            set_batch_count(model, get_adjusted_batch_count(params))
+
+        params.batch_idx_train += 1
+
+        images = images.cuda(non_blocking=True)
+        targets = targets.cuda(non_blocking=True)
+
+        if mixup_fn is not None:
+            images, targets = mixup_fn(images, targets)
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                # compute outputs
+                outputs = model(images)
+                # measure accuracy and record loss
+                loss = criterion(outputs, targets)
+
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+
+            torch.cuda.synchronize()
+
+            # summary stats
+            loss_meter.update(loss.item(), targets.size(0))
+            batch_time.update(time.time() - end)
+            end = time.time()
+        except:  # noqa
+            save_bad_model()
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(params=params, model_cur=model, model_avg=model_avg)
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
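+            # Note: `_scale` is a private attribute of GradScaler; reading it
+            # directly here is a pragmatic workaround (`scaler.get_scale()` is
+            # the public equivalent).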
+            cur_grad_scale = scaler._scale.item()
+
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                save_bad_model()
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = max(scheduler.get_last_lr())
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logging.info(
+                f"Epoch {params.cur_epoch}, batch {batch_idx}/{num_steps}, "
+                f"time {batch_time}, "
+                f"loss {loss_meter}, "
+                f"batch size {targets.size(0)}, "
+                f"lr: {cur_lr:.2e}, "
+                f"mem {memory_used:.0f}MB, "
+                + (f"grad_scale: {cur_grad_scale}" if params.use_fp16 else "")
+            )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+                tb_writer.add_scalar(
+                    "train/current_loss", loss_meter.val, params.batch_idx_train
+                )
+                tb_writer.add_scalar(
+                    "train/averaged_loss", loss_meter.avg, params.batch_idx_train
+                )
+
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+
+    epoch_time = time.time() - start
+    logging.info(
+        f"Epoch {params.cur_epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
+    )
+
+    if loss_meter.avg < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = loss_meter.avg
+
+
+def _to_int_tuple(s: str):
+    return tuple(map(int, s.split(",")))
+
+
+def get_model(params):
+    model = SwinTransformer(
+        img_size=params.img_size,
+        patch_size=params.patch_size,
+        in_chans=params.in_chans,
+        num_classes=params.num_classes,
+        embed_dim=params.embed_dim,
+        depths=_to_int_tuple(params.depths),
+        num_heads=_to_int_tuple(params.num_heads),
+        window_size=params.window_size,
+        mlp_ratio=params.mlp_ratio,
+        qkv_bias=params.qkv_bias,
+        qk_scale=params.qk_scale,
+        drop_rate=params.drop_rate,
+        drop_path_rate=params.drop_path_rate,
+        ape=params.ape,
+        patch_norm=params.patch_norm,
+        fused_window_process=params.fused_window_process,
+    )
+    return model
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed, rank)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is currently unavailable.")
+    device = torch.device("cuda", rank)
+
+    logging.info(f"Device: {device}")
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    # Create datasets and dataloaders
+    imagenet = ImageNetClsDataModule(params)
+    train_dl, mixup_fn = imagenet.build_train_loader(
+        num_classes=params.num_classes, label_smoothing=params.label_smoothing
+    )
+    valid_dl = imagenet.build_val_loader()
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1, rank)
+        if world_size > 1:
+            # For DistributedSampler
+            train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            train_dl=train_dl,
+            mixup_fn=mixup_fn,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        validate(
+            params=params,
+            model=model,
+            valid_dl=valid_dl,
+            world_size=world_size,
+            tb_writer=tb_writer,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            scaler=scaler,
+            rank=rank,
+        )
+
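+    # All epochs finished; per-epoch checkpoints and the best-* copies were
+    # saved by save_checkpoint() above.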
+ logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + ImageNetClsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/imagenet/CLS/swin_transformer/utils.py b/egs/imagenet/CLS/swin_transformer/utils.py new file mode 100644 index 000000000..4ada6d3f9 --- /dev/null +++ b/egs/imagenet/CLS/swin_transformer/utils.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import numpy as np + +import torch +import torch.distributed as dist + + +# We might need to move this file to icefall/utils.py in the future + + +# Copied from https://github.com/microsoft/Swin-Transformer/blob/main/utils.py +def reduce_tensor(tensor): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= dist.get_world_size() + return rt + + +def fix_random_seed(random_seed: int, rank: int): + seed = random_seed + rank + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + +# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self) -> str: + return f"{self.val:.4f} (avg: {self.avg:.4f})" + + +# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = min(max(topk), output.size()[1]) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [ + correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size + for k in topk + ] diff --git a/egs/imagenet/CLS/swin_transformer/validate.py b/egs/imagenet/CLS/swin_transformer/validate.py new file mode 100755 index 000000000..3f86d7d85 --- /dev/null +++ b/egs/imagenet/CLS/swin_transformer/validate.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import time
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+from cls_datamodule import ImageNetClsDataModule
+from train import add_model_arguments, get_params, get_model
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    load_checkpoint,
+)
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    str2bool,
+)
+from utils import AverageMeter, accuracy
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'.",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging.",
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + add_model_arguments(parser) + + return parser + + +def validate( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, +) -> None: + """Run the validation process.""" + batch_time = AverageMeter() + acc1_meter = AverageMeter() + acc5_meter = AverageMeter() + + end = time.time() + for batch_idx, (images, targets) in enumerate(valid_dl): + images = images.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + + # compute outputs + outputs = model(images) + + # measure accuracy and record loss + acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) + + acc1_meter.update(acc1.item(), targets.size(0)) + acc5_meter.update(acc5.item(), targets.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}") + + +@torch.no_grad() +def main(): + parser = get_parser() + ImageNetClsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.exp_dir}/log-decode-{params.suffix}") + logging.info("Validation started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # Create datasets and dataloaders + imagenet = ImageNetClsDataModule(params) + valid_dl = imagenet.build_val_loader() + + validate( + params=params, + model=model, + valid_dl=valid_dl, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main()