diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8b6f67d3f..203861fc8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -519,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
         self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
     ):
         """
-        Show information of parameter wihch dominating tot_sumsq.
+        Show information of parameter which dominates tot_sumsq.
 
         Args:
             tuples: a list of tuples of (param, state, param_names)
@@ -678,12 +678,17 @@ class ScaledAdam(BatchedOptimizer):
         )
 
         is_too_small = param_rms < param_min_rms
-        is_too_large = param_rms > param_max_rms
 
         # when the param gets too small, just don't shrink it any further.
         scale_step.masked_fill_(is_too_small, 0.0)
-        # when it gets too large, stop it from getting any larger.
-        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
+
+        # and ensure the parameter rms after the update never exceeds param_max_rms.
+        # It is worth inspecting the trained model for parameters at or around
+        # param_max_rms, because they can sometimes indicate a problem with the
+        # topology or settings.
+        scale_step = torch.minimum(scale_step,
+                                   (param_max_rms - param_rms) / param_rms)
+
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
         delta.add_(p * scale_step, alpha=(1 - beta1))
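A note on the ScaledAdam change above: the old code simply stopped a parameter from growing once its RMS exceeded param_max_rms, while the new code clamps scale_step so that the projected RMS after the update can never exceed param_max_rms (parameters already above the limit are scaled back down). A minimal sketch of the arithmetic, ignoring the momentum factor and assuming the size update scales the parameter by roughly (1 + scale_step); the values below are hypothetical, not taken from the recipe:

import torch

param_rms = torch.tensor([0.5, 1.9, 3.0])   # current RMS of three parameter tensors
param_max_rms = 2.0
scale_step = torch.tensor([0.2, 0.2, 0.2])  # proposed relative size change

# The clamp from the patch: cap the relative growth at the value that would land
# exactly on param_max_rms; for parameters already above the limit this is negative.
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)

new_rms = param_rms * (1.0 + scale_step)    # approximate RMS after p *= (1 + scale_step)
assert torch.all(new_rms <= param_max_rms + 1e-6)  # 0.6, 2.0, 2.0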
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 6c8ff1813..2454a87a2 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -59,6 +59,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from zipformer import Zipformer
+from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -70,7 +72,6 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from zipformer import Zipformer
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -79,81 +80,125 @@
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
+from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+def get_adjusted_batch_count(
+    params: AttributeDict) -> float:
+    # Returns the number of batches we would have used so far if we had used the reference
+    # duration.  This is for purposes of set_batch_count().
+    return (params.batch_idx_train * (params.max_duration * params.world_size) /
+            params.ref_duration)
+
+
+
+def set_batch_count(
+    model: Union[nn.Module, DDP], batch_count: float
+) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
-    for module in model.modules():
-        if hasattr(module, "batch_count"):
+    for name, module in model.named_modules():
+        if hasattr(module, 'batch_count'):
             module.batch_count = batch_count
+        if hasattr(module, 'name'):
+            module.name = name
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,4,3,2,4",
-        help="Number of zipformer encoder layers, comma separated.",
+        default="4,4,4,4,4,4",
+        help="Number of zipformer encoder layers per stack, comma separated.",
     )
 
-    parser.add_argument(
-        "--feedforward-dims",
-        type=str,
-        default="1024,1024,2048,2048,1024",
-        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
-    )
-
     parser.add_argument(
-        "--nhead",
+        "--downsampling-factor",
         type=str,
-        default="8,8,8,8,8",
-        help="Number of attention heads in the zipformer encoder layers.",
-    )
-
-    parser.add_argument(
-        "--encoder-dims",
-        type=str,
-        default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
-    )
-
-    parser.add_argument(
-        "--attention-dims",
-        type=str,
-        default="192,192,192,192,192",
-        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension.""",
-    )
-
-    parser.add_argument(
-        "--encoder-unmasked-dims",
-        type=str,
-        default="256,256,256,256,256",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-        "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse.",
-    )
-
-    parser.add_argument(
-        "--zipformer-downsampling-factors",
-        type=str,
-        default="1,2,4,8,2",
+        default="1,2,4,8,4,2",
         help="Downsampling factor for each stack of encoder layers.",
     )
 
+
     parser.add_argument(
-        "--cnn-module-kernels",
+        "--feedforward-dim",
         type=str,
-        default="31,31,31,31,31",
-        help="Sizes of kernels in convolution modules",
+        default="1792,1792,2304,2304,2304,1792",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="8,8,8,16,8,8",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--attention-share-layers",
+        type=str,
+        default="2",
+        help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="384",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="32",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
+    )
+
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default=48,
+        help="Positional-encoding embedding dimension"
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list.  Must be <= each corresponding encoder_dim."
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
     )
 
     parser.add_argument(
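Several of the new model options above take "a single int or comma-separated list" and are handed to the Zipformer as per-stack tuples. A minimal sketch of the parsing, mirroring the to_int_tuple helper that appears in get_encoder_model() further down; the length-1 broadcast at the end is an assumption about how the encoder treats single values, not something this diff spells out:

def to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

num_encoder_layers = to_int_tuple("4,4,4,4,4,4")  # one value per encoder stack
encoder_dim = to_int_tuple("384")                 # a single int gives a length-1 tuple
if len(encoder_dim) == 1:
    # assumed: a single value applies to every stack
    encoder_dim = encoder_dim * len(num_encoder_layers)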
+ ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension" + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim." + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", ) parser.add_argument( @@ -244,7 +289,10 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." + "--base-lr", + type=float, + default=0.05, + help="The base learning rate." ) parser.add_argument( @@ -263,11 +311,21 @@ def get_parser(): """, ) + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model" + ) + + parser.add_argument( "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -290,7 +348,8 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( @@ -327,7 +386,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=2000, + default=4000, help="""Save checkpoint after processing this number of batches" periodically. We save checkpoint to exp-dir/ whenever params.batch_idx_train % save_every_n == 0. 
@@ -442,21 +501,24 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
-        return tuple(map(int, s.split(",")))
-
+        return tuple(map(int, s.split(',')))
     encoder = Zipformer(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
-        zipformer_downsampling_factors=to_int_tuple(
-            params.zipformer_downsampling_factors
-        ),
-        encoder_dims=to_int_tuple(params.encoder_dims),
-        attention_dim=to_int_tuple(params.attention_dims),
-        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
-        nhead=to_int_tuple(params.nhead),
-        feedforward_dim=to_int_tuple(params.feedforward_dims),
-        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        downsampling_factor=to_int_tuple(params.downsampling_factor),
         num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+        encoder_dim=to_int_tuple(params.encoder_dim),
+        encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
+        query_head_dim=to_int_tuple(params.query_head_dim),
+        pos_head_dim=to_int_tuple(params.pos_head_dim),
+        value_head_dim=to_int_tuple(params.value_head_dim),
+        pos_dim=params.pos_dim,
+        num_heads=to_int_tuple(params.num_heads),
+        attention_share_layers=to_int_tuple(params.attention_share_layers),
+        feedforward_dim=to_int_tuple(params.feedforward_dim),
+        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
     )
 
     return encoder
@@ -473,7 +535,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        encoder_dim=int(params.encoder_dim.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -490,7 +552,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        encoder_dim=int(params.encoder_dim.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -625,7 +687,7 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute transducer loss given the model and its inputs.
+    Compute the transducer loss given the model and its inputs.
 
     Args:
       params:
@@ -642,7 +704,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -672,24 +738,27 @@
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s
-            if batch_idx_train >= warm_step
+            s if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0
-            if batch_idx_train >= warm_step
+            1.0 if batch_idx_train >= warm_step
            else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
 
-        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        loss = (
+            simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
+        )
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
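For reference, the reformatted warm-up weighting in compute_loss() above behaves as follows (s = params.simple_loss_scale, warm_step = params.warm_step; the helper below is only an illustration, not part of the recipe):

def loss_scales(batch_idx_train: int, warm_step: int, s: float):
    simple = s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    pruned = 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step)
    return simple, pruned

# The simple loss starts at full weight and decays to s, while the pruned loss
# ramps up from 0.1 to 1.0 by warm_step.
print(loss_scales(0, 2000, 0.5))     # (1.0, 0.1)
print(loss_scales(1000, 2000, 0.5))  # (0.75, 0.55)
print(loss_scales(2000, 2000, 0.5))  # (0.5, 1.0)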
@@ -784,7 +853,22 @@
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
+    saved_bad_model = False
+    def save_bad_model(suffix: str = ""):
+        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+                             model=model,
+                             model_avg=model_avg,
+                             params=params,
+                             optimizer=optimizer,
+                             scheduler=scheduler,
+                             sampler=train_dl.sampler,
+                             scaler=scaler,
+                             rank=0)
+
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx % 10 == 0:
+            set_batch_count(model, get_adjusted_batch_count(params))
         if batch_idx < cur_batch_idx:
             continue
         cur_batch_idx = batch_idx
 
@@ -807,13 +891,13 @@
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
             scheduler.step_batch(params.batch_idx_train)
 
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
+            save_bad_model()
             display_and_save_batch(batch, params=params, sp=sp)
             raise
@@ -860,14 +944,17 @@
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    save_bad_model()
+                    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
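The fp16 branch above now grows the AMP grad scale more aggressively (below 8.0 it is doubled every step, below 32.0 every 400 batches) and saves a "bad model" checkpoint before warning or aborting. A minimal sketch of just the scale-growth logic, factored out under the assumption that scaler is an active torch.cuda.amp.GradScaler inside the training loop:

from torch.cuda.amp import GradScaler

def maybe_grow_grad_scale(scaler: GradScaler, batch_idx: int) -> None:
    # mirrors the thresholds in the hunk above; _scale is the scaler's internal tensor
    cur_grad_scale = scaler._scale.item()
    if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
        # GradScaler.update() accepts an explicit new scale value
        scaler.update(cur_grad_scale * 2.0)
    if cur_grad_scale < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")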
@@ -877,8 +964,8 @@
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, " +
+                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
 
             if tb_writer is not None:
@@ -889,14 +976,16 @@
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale",
-                        cur_grad_scale,
-                        params.batch_idx_train,
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                     )
 
+
+
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -908,9 +997,7 @@
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
+            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
@@ -937,8 +1024,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
    params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -986,7 +1071,8 @@
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank],
+                    find_unused_parameters=True)
 
     optimizer = ScaledAdam(
         model.named_parameters(),
@@ -1010,7 +1096,7 @@
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1087,7 +1173,8 @@
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16,
+                        init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1208,9 +1295,7 @@
             )
             display_and_save_batch(batch, params=params, sp=sp)
             raise
-        logging.info(
-            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        )
+        logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
 
 
 def main():
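The dropout argument passed to the Zipformer in get_encoder_model(), ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), is a value scheduled on each module's batch_count, which set_batch_count() keeps updated with the adjusted count. The real implementation lives in scaling.py; the sketch below only illustrates the assumed piecewise-linear reading of those two (batch_count, value) points:

def scheduled_float(batch_count: float,
                    points=((0.0, 0.3), (20000.0, 0.1))) -> float:
    (x0, y0), (x1, y1) = points
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

print(scheduled_float(0))      # 0.3  (early training: heavier dropout)
print(scheduled_float(10000))  # 0.2
print(scheduled_float(30000))  # 0.1  (fully decayed)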