#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                  Wei Kang,
#                                                  Mingshuang Luo,
#                                                  Zengwei Yao,
#                                                  Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a Swin Transformer image classifier on ImageNet.

The model is optimized with icefall's ScaledAdam optimizer and Eden learning
rate scheduler, optionally with fp16 mixed precision and multi-GPU DDP
(one process per GPU, launched via ``torch.multiprocessing.spawn``).
"""

import argparse
import copy
import logging
import datetime
import time
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union

import optim
import torch
import torch.multiprocessing as mp
from cls_datamodule import ImageNetClsDataModule
from optim import Eden, ScaledAdam
from utils import AverageMeter, accuracy, fix_random_seed, reduce_tensor
from timm.data import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from torch import nn
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    setup_logger,
    str2bool,
    get_parameter_groups_with_lrs,
)

from swin_transformer import SwinTransformer

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def get_adjusted_batch_count(params: AttributeDict) -> float:
    """Return the number of batches we would have used so far.

    The per-process counter ``params.batch_idx_train`` is multiplied by
    ``params.world_size``. The result is what :func:`set_batch_count`
    pushes into the model.
    """
    return params.batch_idx_train * params.world_size


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    """Propagate ``batch_count`` to every submodule that tracks it.

    Every submodule that has a ``batch_count`` attribute gets it overwritten
    with ``batch_count``. Every submodule that has a ``name`` attribute gets
    its qualified name as reported by ``named_modules()``.

    Args:
      model:
        The model, optionally wrapped in DDP.
      batch_count:
        Typically the value returned by :func:`get_adjusted_batch_count`.
    """
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for name, module in model.named_modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
        if hasattr(module, "name"):
            module.name = name


def add_model_arguments(parser: argparse.ArgumentParser):
    """Add the SwinTransformer hyper-parameter options to ``parser``."""
    parser.add_argument(
        "--patch-size",
        type=int,
        default=4,
        help="Patch size. Default: 4",
    )

    parser.add_argument(
        "--embed-dim",
        type=int,
        default=96,
        help="Patch embedding dimension. Default: 96",
    )

    # Comma-separated lists; converted to int tuples in get_model().
    parser.add_argument(
        "--depths",
        type=str,
        default="2,2,6,2",
        help="Depth of each Swin Transformer layer.",
    )

    parser.add_argument(
        "--num-heads",
        type=str,
        default="3,6,12,24",
        help="Number of attention heads in different layers.",
    )

    parser.add_argument(
        "--window-size",
        type=int,
        default=7,
        help="Window size. Default: 7",
    )

    parser.add_argument(
        "--mlp-ratio",
        type=float,
        default=4.0,
        help="Ratio of mlp hidden dim to embedding dim. Default: 4",
    )

    parser.add_argument(
        "--qkv-bias",
        type=str2bool,
        default=True,
        help="If True, add a learnable bias to query, key, value. Default: True",
    )

    parser.add_argument(
        "--qk-scale",
        type=float,
        default=None,
        help="Override default qk scale of head_dim ** -0.5 if set. Default: None",
    )

    parser.add_argument(
        "--ape",
        type=str2bool,
        default=False,
        help="If True, add absolute position embedding to the patch embedding. Default: False",
    )

    parser.add_argument(
        "--patch-norm",
        type=str2bool,
        default=True,
        help="If True, add normalization after patch embedding. Default: True",
    )

    parser.add_argument(
        "--drop-rate",
        type=float,
        default=0.0,
        help="Dropout rate",
    )

    parser.add_argument(
        "--drop-path-rate",
        type=float,
        default=0.1,
        help="Drop path rate",
    )

    parser.add_argument(
        "--fused-window-process",
        type=str2bool,
        default=False,
        help="If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False",
    )


def get_parser():
    """Build the command-line parser for training options and model options."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=30,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="swin_transformer/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    # NOTE(review): not referenced anywhere in this image-classification
    # script (leftover from the ASR recipes); kept for CLI compatibility.
    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--base-lr", type=float, default=0.025, help="The base learning rate."
    )

    parser.add_argument(
        "--lr-batches",
        type=float,
        default=7500,
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not to change this.""",
    )

    parser.add_argument(
        "--lr-epochs",
        type=float,
        default=3.5,
        help="""Number of epochs that affects how rapidly the learning rate decreases.
        """,
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--label-smoothing",
        type=float,
        default=0.1,
        help="Label smoothing used in loss computation",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--average-period",
        type=int,
        default=200,
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    add_model_arguments(parser)

    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during the training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during the training.

        - best_accuracy: Best top-1 validation accuracy so far (logged as a
                         percentage). It is updated during validation.

        - best_train_epoch: It is the epoch that has the best training loss.

        - best_valid_epoch: It is the epoch that has the best validation loss.

        - batch_idx_train: Number of batches trained so far across epochs.
                           Used as the tensorboard step and to schedule model
                           averaging and the learning rate.

        - log_interval: Print training loss if `batch_idx % log_interval` is 0

        - reset_interval: Not referenced elsewhere in this script.

        - valid_interval: Not referenced elsewhere in this script; validation
                          runs once at the end of every epoch.

        - valid_log_interval: Print validation progress if
                              `batch_idx % valid_log_interval` is 0.

        - img_size: Input image size passed to SwinTransformer.

        - in_chans: Number of input image channels.

        - num_classes: Number of output classes.

        - env_info: Environment information returned by `get_env_info()`.
    """
    params = AttributeDict(
        {
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_accuracy": 0.0,  # acc1
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 3000,  # not used by the loop in this script
            "valid_log_interval": 10,
            # parameters for SwinTransformer
            "img_size": 224,
            "in_chans": 3,
            "num_classes": 1000,
            "env_info": get_env_info(),
        }
    )

    return params


def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_epoch is larger than 1, it will load the checkpoint from
    `params.start_epoch - 1`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    `best_valid_loss`, `batch_idx_train` and `best_accuracy` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The scheduler that we are using.
    Returns:
      Return a dict containing previously saved training info, or None when
      ``params.start_epoch`` is 1 (nothing to resume from).
    """
    if params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        model_avg=model_avg,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
        "best_accuracy",
    ]
    for k in keys:
        params[k] = saved_params[k]

    return saved_params


def save_checkpoint(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    Only rank 0 writes anything. The checkpoint goes to
    ``exp_dir/epoch-{cur_epoch}.pt`` and is additionally copied to
    ``best-train-loss.pt`` / ``best-valid-loss.pt`` when the current epoch is
    the best so far for that metric.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
      model_avg:
        The stored model averaged from the start of training.
      optimizer:
        The optimizer used in the training.
      scheduler:
        The learning rate scheduler used in the training.
      scaler:
        The scaler used for mix precision training.
      rank:
        Rank of this process; nothing is saved unless it is 0.
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        rank=rank,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


@torch.no_grad()
def validate(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    tb_writer: Optional[SummaryWriter] = None,
) -> None:
    """Run the validation process.

    Evaluates ``model`` on ``valid_dl`` with plain cross-entropy loss and
    top-1 / top-5 accuracy. When ``world_size > 1`` the per-batch loss and
    accuracies are passed through ``reduce_tensor`` (presumably an all-reduce
    average across processes -- TODO confirm in utils). Progress is logged
    every ``params.valid_log_interval`` batches. Averages are written to
    ``tb_writer`` if given. ``params.best_valid_loss``,
    ``params.best_valid_epoch`` and ``params.best_accuracy`` are updated.

    The model is left in eval mode; :func:`train_one_epoch` switches it back.
    """
    model.eval()

    criterion = torch.nn.CrossEntropyLoss()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for batch_idx, (images, targets) in enumerate(valid_dl):
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        # compute outputs
        outputs = model(images)

        # measure accuracy and record loss
        loss = criterion(outputs, targets)
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))

        if world_size > 1:
            acc1 = reduce_tensor(acc1)
            acc5 = reduce_tensor(acc5)
            loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), targets.size(0))
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % params.valid_log_interval == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logging.info(
                f"Test: [{batch_idx}/{len(valid_dl)}]\t"
                f"Time {batch_time}\t"
                f"Loss {loss_meter}\t"
                f"Acc@1 {acc1_meter}\t"
                f"Acc@5 {acc5_meter}\t"
                f"Mem {memory_used:.0f}MB"
            )

    logging.info(f" * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}")

    if tb_writer is not None:
        tb_writer.add_scalar("train/valid_loss", loss_meter.avg, params.batch_idx_train)
        tb_writer.add_scalar("train/valid_acc1", acc1_meter.avg, params.batch_idx_train)
        tb_writer.add_scalar("train/valid_acc5", acc5_meter.avg, params.batch_idx_train)

    if loss_meter.avg < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_meter.avg

    if acc1_meter.avg > params.best_accuracy:
        params.best_accuracy = acc1_meter.avg
    logging.info(f"Best accuracy: {params.best_accuracy:.2f}%")


def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    mixup_fn: Optional[Mixup] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    At the end of the epoch, the running average of this process's per-batch
    training loss updates ``params.best_train_loss`` and
    ``params.best_train_epoch`` if it improved. Validation is NOT run here;
    the caller runs :func:`validate` after each epoch.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler, we call step_batch() every step.
      train_dl:
        Dataloader for the training dataset.
      scaler:
        The scaler used for mix precision training.
      model_avg:
        The stored model averaged from the start of training (rank 0 only).
      tb_writer:
        Writer to write log messages to tensorboard.
      mixup_fn:
        If not None, it is applied to every training batch and turns the
        targets into soft targets, so a soft-target loss is used.
      world_size:
        Number of processes in DDP training. Not used inside this function.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()

    if mixup_fn is not None:
        # mixup_fn produces soft targets (smoothing is handled with the
        # mixup label transform), so a soft-target loss is required. Keying
        # on mixup_fn -- the transform actually applied below -- rather than
        # on params.mixup keeps the loss consistent with the targets no
        # matter which option enabled the transform.
        criterion = SoftTargetCrossEntropy()
    elif params.label_smoothing > 0.0:
        criterion = LabelSmoothingCrossEntropy(smoothing=params.label_smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        # Dump the full training state for post-mortem debugging.
        save_checkpoint_impl(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            model_avg=model_avg,
            params=params,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            rank=0,
        )

    batch_time = AverageMeter()
    loss_meter = AverageMeter()

    num_steps = len(train_dl)
    start = time.time()
    end = time.time()
    for batch_idx, (images, targets) in enumerate(train_dl):
        if batch_idx % 10 == 0:
            set_batch_count(model, get_adjusted_batch_count(params))

        params.batch_idx_train += 1

        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            images, targets = mixup_fn(images, targets)

        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                # compute outputs
                outputs = model(images)
                # compute loss
                loss = criterion(outputs, targets)

            # Backward runs outside the autocast region, as recommended for AMP.
            scaler.scale(loss).backward()
            scheduler.step_batch(params.batch_idx_train)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            torch.cuda.synchronize()

            # summary stats
            loss_meter.update(loss.item(), targets.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
        except BaseException:
            # Save the state for debugging, then propagate the error.
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if (
            rank == 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            update_averaged_model(
                params=params,
                model_cur=model,
                model_avg=model_avg,
            )

        if batch_idx % 100 == 0 and params.use_fp16:
            # If the grad scale is small (< 8, or < 32 on every 400th batch),
            # double it. The _growth_interval of the grad scaler is
            # configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if batch_idx % params.log_interval == 0:
            cur_lr = max(scheduler.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}/{num_steps}, "
                f"time {batch_time}, "
                f"loss {loss_meter}, "
                f"batch size {targets.size(0)}, "
                f"lr: {cur_lr:.2e}, "
                f"mem {memory_used:.0f}MB, "
                + (f"grad_scale: {cur_grad_scale}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/current_loss", loss_meter.val, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/averaged_loss", loss_meter.avg, params.batch_idx_train
                )
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )

    epoch_time = time.time() - start
    logging.info(
        f"Epoch {params.cur_epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )

    if loss_meter.avg < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = loss_meter.avg


def _to_int_tuple(s: str):
    """Convert a comma-separated string such as "2,2,6,2" to a tuple of ints."""
    return tuple(map(int, s.split(",")))


def get_model(params):
    """Build a SwinTransformer from the hyper-parameters stored in ``params``.

    ``params.depths`` and ``params.num_heads`` are comma-separated strings
    that are converted to int tuples.
    """
    model = SwinTransformer(
        img_size=params.img_size,
        patch_size=params.patch_size,
        in_chans=params.in_chans,
        num_classes=params.num_classes,
        embed_dim=params.embed_dim,
        depths=_to_int_tuple(params.depths),
        num_heads=_to_int_tuple(params.num_heads),
        window_size=params.window_size,
        mlp_ratio=params.mlp_ratio,
        qkv_bias=params.qkv_bias,
        qk_scale=params.qk_scale,
        drop_rate=params.drop_rate,
        drop_path_rate=params.drop_path_rate,
        ape=params.ape,
        patch_norm=params.patch_norm,
        fused_window_process=params.fused_window_process,
    )
    return model


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed, rank)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is currently unavailable.")
    device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    model_avg: Optional[nn.Module] = None
    if rank == 0:
        # model_avg is only used with rank 0
        model_avg = copy.deepcopy(model).to(torch.float64)

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(
        params=params, model=model, model_avg=model_avg
    )

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr,  # should have no effect
        clipping_scale=2.0,
    )

    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        logging.info("Loading scheduler state dict")
        scheduler.load_state_dict(checkpoints["scheduler"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    # Create datasets and dataloaders
    imagenet = ImageNetClsDataModule(params)
    train_dl, mixup_fn = imagenet.build_train_loader(
        num_classes=params.num_classes, label_smoothing=params.label_smoothing
    )
    valid_dl = imagenet.build_val_loader()

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        scheduler.step_epoch(epoch - 1)
        fix_random_seed(params.seed + epoch - 1, rank)
        if world_size > 1:
            # For DistributedSampler
            train_dl.sampler.set_epoch(epoch - 1)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            mixup_fn=mixup_fn,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        validate(
            params=params,
            model=model,
            valid_dl=valid_dl,
            world_size=world_size,
            tb_writer=tb_writer,
        )

        save_checkpoint(
            params=params,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            rank=rank,
        )

    logging.info("Done!")

    # Flush any pending tensorboard events and release the event file.
    if tb_writer is not None:
        tb_writer.close()

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    """Parse the command line and launch training (one process per GPU)."""
    parser = get_parser()
    ImageNetClsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


# Module level so that it also takes effect in processes started by mp.spawn.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()