Merge 5beef022852e1dea957530b7f27d467eb7373415 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

This commit is contained in:
Zengwei Yao 2025-08-17 06:14:31 +00:00 committed by GitHub
commit 2cd098d344
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 2138 additions and 0 deletions

View File

@ -0,0 +1,317 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is modified from https://github.com/microsoft/Swin-Transformer/blob/main/data/build.py
# The default args are copied from https://github.com/microsoft/Swin-Transformer/blob/main/config.py
# We adjust the code style as other recipes in icefall.
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
from icefall.dist import get_rank, get_world_size
from icefall.utils import str2bool
from timm.data import Mixup, create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms
try:
from torchvision.transforms import InterpolationMode
def _pil_interp(method):
if method == "bicubic":
return InterpolationMode.BICUBIC
elif method == "lanczos":
return InterpolationMode.LANCZOS
elif method == "hamming":
return InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return InterpolationMode.BILINEAR
import timm.data.transforms as timm_transforms
timm_transforms._pil_interp = _pil_interp
except: # noqa
from timm.data.transforms import _pil_interp
class ImageNetClsDataModule:
def __init__(self, args: argparse.Namespace):
self.args = args
self.rank = get_rank()
self.world_size = get_world_size()
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="Image classification data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders -- they control the effective batch sizes, "
"sampling strategies, applied data augmentations, etc.",
)
group.add_argument(
"--data-path",
type=Path,
default=Path("imagenet"),
help="Path to imagenet dataset,",
)
group.add_argument(
"--batch-size",
type=int,
default=128,
help="Batch size for a single GPU, could be overwritten by command line argument",
)
group.add_argument(
"--color-jitter",
type=float,
default=0.4,
help="Color jitter factor",
)
group.add_argument(
"--auto-augment",
type=str,
default="rand-m9-mstd0.5-inc1",
help="AutoAugment policy. 'v0' or 'original'",
)
group.add_argument(
"--reprob",
type=float,
default=0.25,
help="Random erase prob",
)
group.add_argument(
"--remode",
type=str,
default="pixel",
help="Random erase mode",
)
group.add_argument(
"--recount",
type=int,
default=1,
help="Random erase count",
)
group.add_argument(
"--interpolation",
type=str,
default="bicubic",
help="Interpolation to resize image (random, bilinear, bicubic)",
)
group.add_argument(
"--crop",
type=str2bool,
default=True,
help="Whether to use center crop when testing",
)
group.add_argument(
"--mixup",
type=float,
default=0.8,
help="Mixup alpha, mixup enabled if > 0",
)
group.add_argument(
"--cutmix",
type=float,
default=1.0,
help="Cutmix alpha, cutmix enabled if > 0",
)
group.add_argument(
"--cutmix-minmax",
type=float,
default=None,
help="Cutmix min/max ratio, overrides alpha and enables cutmix if set",
)
group.add_argument(
"--mixup-prob",
type=float,
default=1.0,
help="Probability of performing mixup or cutmix when either/both is enabled",
)
group.add_argument(
"--mixup-switch-prob",
type=float,
default=0.5,
help="Probability of switching to cutmix when both mixup and cutmix enabled",
)
group.add_argument(
"--mixup-mode",
type=str,
default="batch",
help="How to apply mixup/cutmix params. Per 'batch', 'pair', or 'elem'",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="Number of data loading threads",
)
group.add_argument(
"--pin-memory",
type=str2bool,
default=True,
help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
)
def build_transform(self, is_training: bool = False):
resize_im = self.args.img_size > 32
if is_training:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=self.args.img_size,
is_training=True,
color_jitter=self.args.color_jitter
if self.args.color_jitter > 0
else None,
auto_augment=self.args.auto_augment
if self.args.auto_augment != "none"
else None,
re_prob=self.args.reprob,
re_mode=self.args.remode,
re_count=self.args.recount,
interpolation=self.args.interpolation,
)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(
self.args.img_size, padding=4
)
return transform
t = []
if resize_im:
if self.args.crop:
size = int((256 / 224) * self.args.img_size)
t.append(
transforms.Resize(
size, interpolation=_pil_interp(self.args.interpolation)
),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(self.args.img_size))
else:
t.append(
transforms.Resize(
(self.args.img_size, self.args.img_size),
interpolation=_pil_interp(self.args.interpolation),
)
)
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)
def build_dataset(self, is_training: bool = False):
transform = self.build_transform(is_training)
prefix = "train" if is_training else "val"
root = os.path.join(self.args.data_path, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
return dataset
def build_train_loader(
self, num_classes: int, label_smoothing: Optional[float] = None
):
assert num_classes == 1000, num_classes
dataset_train = self.build_dataset(is_training=True)
logging.info(f"rank {self.rank} successfully build train dataset")
if self.world_size > 1:
sampler_train = DistributedSampler(
dataset_train,
num_replicas=self.world_size,
rank=self.rank,
shuffle=True,
)
else:
sampler_train = RandomSampler(dataset_train)
# TODO: need to set up worker_init_fn?
data_loader_train = DataLoader(
dataset_train,
sampler=sampler_train,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
drop_last=True,
)
# setup mixup / cutmix
mixup_fn = None
mixup_active = (
self.args.mixup > 0
or self.args.cutmix > 0.0
or self.args.cutmix_minmax is not None
)
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=self.args.mixup,
cutmix_alpha=self.args.cutmix,
cutmix_minmax=self.args.cutmix_minmax,
prob=self.args.mixup_prob,
switch_prob=self.args.mixup_switch_prob,
mode=self.args.mixup_mode,
label_smoothing=label_smoothing,
num_classes=num_classes,
)
return data_loader_train, mixup_fn
def build_val_loader(self):
dataset_val = self.build_dataset(is_training=False)
logging.info(f"rank {self.rank} successfully build val dataset")
if self.world_size > 1:
sampler_val = DistributedSampler(
dataset_val, num_replicas=self.world_size, rank=self.rank, shuffle=False
)
else:
sampler_val = SequentialSampler(dataset_val)
data_loader_val = DataLoader(
dataset_val,
sampler=sampler_val,
batch_size=self.args.batch_size,
shuffle=False,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
drop_last=False,
)
return data_loader_val

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1,617 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
# copied from https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
try:
import os, sys
kernel_path = os.path.abspath(os.path.join('..'))
sys.path.append(kernel_path)
from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
except:
WindowProcess = None
WindowProcessReverse = None
print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.")
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
fused_window_process=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
self.fused_window_process = fused_window_process
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
# FFN
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
fused_window_process=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
fused_window_process=fused_window_process)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, fused_window_process=False, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
fused_window_process=fused_window_process)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops

View File

@ -0,0 +1,924 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import logging
import datetime
import time
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
import optim
import torch
import torch.multiprocessing as mp
from cls_datamodule import ImageNetClsDataModule
from optim import Eden, ScaledAdam
from utils import AverageMeter, accuracy, fix_random_seed, reduce_tensor
from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from torch import nn
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
str2bool,
get_parameter_groups_with_lrs,
)
from swin_transformer import SwinTransformer
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def get_adjusted_batch_count(params: AttributeDict) -> float:
# Returns the number of batches we would have used so far.
# This is for purposes of set_batch_count().
return params.batch_idx_train * params.world_size
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for name, module in model.named_modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
if hasattr(module, "name"):
module.name = name
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--patch-size",
type=int,
default=4,
help="Patch size. Default: 4",
)
parser.add_argument(
"--embed-dim",
type=int,
default=96,
help="Patch embedding dimension. Default: 96",
)
parser.add_argument(
"--depths",
type=str,
default="2,2,6,2",
help="Depth of each Swin Transformer layer.",
)
parser.add_argument(
"--num-heads",
type=str,
default="3,6,12,24",
help="Number of attention heads in different layers.",
)
parser.add_argument(
"--window-size",
type=int,
default=7,
help="Window size. Default: 7",
)
parser.add_argument(
"--mlp-ratio",
type=float,
default=4.0,
help="Ratio of mlp hidden dim to embedding dim. Default: 4",
)
parser.add_argument(
"--qkv-bias",
type=str2bool,
default=True,
help="If True, add a learnable bias to query, key, value. Default: True",
)
parser.add_argument(
"--qk-scale",
type=float,
default=None,
help="Override default qk scale of head_dim ** -0.5 if set. Default: None",
)
parser.add_argument(
"--ape",
type=str2bool,
default=False,
help="If True, add absolute position embedding to the patch embedding. Default: False",
)
parser.add_argument(
"--patch-norm",
type=str2bool,
default=True,
help="If True, add normalization after patch embedding. Default: True",
)
parser.add_argument(
"--drop-rate",
type=float,
default=0.0,
help="Dropout rate",
)
parser.add_argument(
"--drop-path-rate",
type=float,
default=0.1,
help="Drop path rate",
)
parser.add_argument(
"--fused-window-process",
type=str2bool,
default=False,
help="If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="swin_transformer/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--base-lr", type=float, default=0.025, help="The base learning rate."
)
parser.add_argument(
"--lr-batches",
type=float,
default=7500,
help="""Number of steps that affects how rapidly the learning rate
decreases. We suggest not to change this.""",
)
parser.add_argument(
"--lr-epochs",
type=float,
default=3.5,
help="""Number of epochs that affects how rapidly the learning rate decreases.
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--label-smoothing",
type=float,
default=0.1,
help="Label smoothing used in loss computation",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--average-period",
type=int,
default=200,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
parameters from the start of training. Each time we take the average,
we do: `model_avg = model * (average_period / batch_idx_train) +
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warmup period that dictates the decay of the
scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_accuracy": 0.0, # acc1
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
"valid_log_interval": 10,
# parameters for SwinTransformer
"img_size": 224,
"in_chans": 3,
"num_classes": 1000,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
model_avg: nn.Module = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`, `best_valid_loss`,
and `best_accuracy` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
model_avg:
The stored model averaged from the start of training.
optimizer:
The optimizer that we are using.
scheduler:
The scheduler that we are using.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(
filename,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
"best_accuracy",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
model_avg:
The stored model averaged from the start of training.
optimizer:
The optimizer used in the training.
scaler:
The scaler used for mix precision training.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
scaler=scaler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
@torch.no_grad()
def validate(
params: AttributeDict,
model: Union[nn.Module, DDP],
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
tb_writer: Optional[SummaryWriter] = None,
) -> None:
"""Run the validation process."""
model.eval()
criterion = torch.nn.CrossEntropyLoss()
batch_time = AverageMeter()
loss_meter = AverageMeter()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
end = time.time()
for batch_idx, (images, targets) in enumerate(valid_dl):
images = images.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
# compute outputs
outputs = model(images)
# measure accuracy and record loss
loss = criterion(outputs, targets)
acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
if world_size > 1:
acc1 = reduce_tensor(acc1)
acc5 = reduce_tensor(acc5)
loss = reduce_tensor(loss)
loss_meter.update(loss.item(), targets.size(0))
acc1_meter.update(acc1.item(), targets.size(0))
acc5_meter.update(acc5.item(), targets.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % params.valid_log_interval == 0:
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
logging.info(
f"Test: [{batch_idx}/{len(valid_dl)}]\t"
f"Time {batch_time}\t"
f"Loss {loss_meter}\t"
f"Acc@1 {acc1_meter}\t"
f"Acc@5 {acc5_meter}\t"
f"Mem {memory_used:.0f}MB"
)
logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")
if tb_writer is not None:
tb_writer.add_scalar("train/valid_loss", loss_meter.avg, params.batch_idx_train)
tb_writer.add_scalar("train/valid_acc1", acc1_meter.avg, params.batch_idx_train)
tb_writer.add_scalar("train/valid_acc5", acc5_meter.avg, params.batch_idx_train)
if loss_meter.avg < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_meter.avg
if acc1_meter.avg > params.best_accuracy:
params.best_accuracy = acc1_meter.avg
logging.info(f"Best accuracy: {params.best_accuracy:.2f}%")
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
mixup_fn: Optional[Mixup] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
scaler:
The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
if params.mixup > 0.0:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif params.label_smoothing > 0.0:
criterion = LabelSmoothingCrossEntropy(smoothing=params.label_smoothing)
else:
criterion = torch.nn.CrossEntropyLoss()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
scaler=scaler,
rank=0,
)
batch_time = AverageMeter()
loss_meter = AverageMeter()
num_steps = len(train_dl)
start = time.time()
end = time.time()
for batch_idx, (images, targets) in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
params.batch_idx_train += 1
images = images.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
if mixup_fn is not None:
images, targets = mixup_fn(images, targets)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
# compute outputs
outputs = model(images)
# measure accuracy and record loss
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
torch.cuda.synchronize()
# summary stats
loss_meter.update(loss.item(), targets.size(0))
batch_time.update(time.time() - end)
end = time.time()
except: # noqa
save_bad_model()
raise
if params.print_diagnostics and batch_idx == 5:
return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(params=params, model_cur=model, model_avg=model_avg)
if batch_idx % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}/{num_steps}, "
f"time {batch_time}, "
f"loss {loss_meter}, "
f"batch size {targets.size(0)}, "
f"lr: {cur_lr:.2e}, "
f"mem {memory_used:.0f}MB, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar(
"train/current_loss", loss_meter.val, params.batch_idx_train
)
tb_writer.add_scalar(
"train/averaged_loss", loss_meter.avg, params.batch_idx_train
)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
epoch_time = time.time() - start
logging.info(
f"Epoch {params.cur_epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
)
if loss_meter.avg < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = loss_meter.avg
def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
def get_model(params):
model = SwinTransformer(
img_size=params.img_size,
patch_size=params.patch_size,
in_chans=params.in_chans,
num_classes=params.num_classes,
embed_dim=params.embed_dim,
depths=_to_int_tuple(params.depths),
num_heads=_to_int_tuple(params.num_heads),
window_size=params.window_size,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
qk_scale=params.qk_scale,
drop_rate=params.drop_rate,
drop_path_rate=params.drop_path_rate,
ape=params.ape,
patch_norm=params.patch_norm,
fused_window_process=params.fused_window_process,
)
return model
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed, rank)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
if not torch.cuda.is_available():
raise RuntimeError("CUDA is currently unavailable.")
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
# Create datasets and dataloaders
imagenet = ImageNetClsDataModule(params)
train_dl, mixup_fn = imagenet.build_train_loader(
num_classes=params.num_classes, label_smoothing=params.label_smoothing
)
valid_dl = imagenet.build_val_loader()
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1, rank)
if world_size > 1:
# For DistributedSampler
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
mixup_fn=mixup_fn,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
validate(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
tb_writer=tb_writer,
)
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
scaler=scaler,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
ImageNetClsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import torch
import torch.distributed as dist
# We might need to move this file to icefall/utils.py in the future
# Copied from https://github.com/microsoft/Swin-Transformer/blob/main/utils.py
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= dist.get_world_size()
return rt
def fix_random_seed(random_seed: int, rank: int):
seed = random_seed + rank
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self) -> str:
return f"{self.val:.4f} (avg: {self.avg:.4f})"
# Copied from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/metrics.py
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
maxk = min(max(topk), output.size()[1])
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
return [
correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size
for k in topk
]

View File

@ -0,0 +1,201 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import time
from pathlib import Path
import torch
import torch.nn as nn
from cls_datamodule import ImageNetClsDataModule
from train import add_model_arguments, get_params, get_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
str2bool,
)
from utils import AverageMeter, accuracy
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
add_model_arguments(parser)
return parser
def validate(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
) -> None:
"""Run the validation process."""
batch_time = AverageMeter()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
end = time.time()
for batch_idx, (images, targets) in enumerate(valid_dl):
images = images.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
# compute outputs
outputs = model(images)
# measure accuracy and record loss
acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
acc1_meter.update(acc1.item(), targets.size(0))
acc5_meter.update(acc5.item(), targets.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")
@torch.no_grad()
def main():
parser = get_parser()
ImageNetClsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.exp_dir}/log-decode-{params.suffix}")
logging.info("Validation started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# Create datasets and dataloaders
imagenet = ImageNetClsDataModule(params)
valid_dl = imagenet.build_val_loader()
validate(
params=params,
model=model,
valid_dl=valid_dl,
)
logging.info("Done!")
if __name__ == "__main__":
main()