From 096ebeaf23b8d551cf29668dd700064896103faf Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 5 Jan 2023 12:01:42 +0800
Subject: [PATCH] take a couple files from liyong's branch

---
 .../ASR/pruned_transducer_stateless7/optim.py | 162 +++++++--
 .../ASR/pruned_transducer_stateless7/train.py | 310 +++++++-----------
 2 files changed, 263 insertions(+), 209 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 2a7c3d89a..8b6f67d3f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -18,7 +18,7 @@ import contextlib
 import logging
 import random
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from lhotse.utils import fix_random_seed
@@ -132,6 +132,9 @@ class ScaledAdam(BatchedOptimizer):
     Args:
          params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
+                  Unlike common optimizers, which accept model.parameters() or groups of parameters(),
+                  this optimizer can also accept model.named_parameters() or groups of named_parameters().
+                  See the comments of the function _get_names_of_parameters for the 4 possible cases.
      lr:  The learning rate.  We will typically use a learning rate schedule that starts
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
@@ -178,15 +181,8 @@ class ScaledAdam(BatchedOptimizer):
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
-        parameters_names=None,
-        show_dominant_parameters=True,
     ):
 
-        assert parameters_names is not None, (
-            "Please prepare parameters_names,"
-            "which is a List[List[str]]. Each List[str] is for a group"
-            "and each str is for a parameter"
-        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -200,10 +196,135 @@ class ScaledAdam(BatchedOptimizer):
             clipping_update_period=clipping_update_period,
         )
 
+        # If params only contains parameters or groups of parameters,
+        # i.e. when parameter names are not given,
+        # this flag will be set to False in the function _get_names_of_parameters.
+        self.show_dominant_parameters = True
+        params, parameters_names = self._get_names_of_parameters(params)
         super(ScaledAdam, self).__init__(params, defaults)
         assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
-        self.show_dominant_parameters = show_dominant_parameters
+
+    def _get_names_of_parameters(
+        self, params_or_named_params
+    ) -> Tuple[List[Dict], List[List[str]]]:
+        """
+        Args:
+          params_or_named_params: according to the way ScaledAdam is initialized in train.py,
+            this argument can be one of the following 4 cases,
+            case 1, a generator of parameters, e.g.:
+              optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 2, a list of parameter groups with different configs, e.g.:
+              model_param_groups = [
+                {'params': model.encoder.parameters(), 'lr': 0.05},
+                {'params': model.decoder.parameters(), 'lr': 0.01},
+                {'params': model.joiner.parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+            case 3, a generator of named_parameters, e.g.:
+              optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 4, a list of named_parameter groups with different configs, e.g.:
+              model_named_param_groups = [
+                {'params': model.encoder.named_parameters(), 'lr': 0.05},
+                {'params': model.decoder.named_parameters(), 'lr': 0.01},
+                {'params': model.joiner.named_parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+          For case 1 and case 2, the input params are used to initialize the underlying torch.optim.Optimizer.
+          For case 3 and case 4, names and params are first extracted from the input named_params;
+          the extracted params are then used to initialize the underlying torch.optim.Optimizer,
+          and the extracted names are mainly used by the function
+          `_show_gradient_dominating_parameter`.
+
+        Returns:
+          Returns a tuple containing 2 elements:
+          - `param_groups` with type List[Dict], each Dict element is a parameter group.
+            An example of `param_groups` could be:
+            [
+              {'params': `one iterable of Parameter`, 'lr': 0.05},
+              {'params': `another iterable of Parameter`, 'lr': 0.08},
+              {'params': `a third iterable of Parameter`, 'lr': 0.1},
+            ]
+          - `param_groups_names` with type List[List[str]];
+            each `List[str]` is for a group['params'] in param_groups,
+            and each `str` is the name of a parameter.
+            A dummy name "foo" is assigned to each parameter
+            if the input params have no names, i.e. in case 1 or case 2.
+        """
+        # variable naming convention in this function:
+        # p is short for param.
+        # np is short for named_param.
+        # p_or_np is short for param_or_named_param.
+        # cur is short for current.
+        # group is a dict, e.g. {'params': iterable of parameters, 'lr': 0.05, other fields}.
+        # groups is a List[group]
+
+        iterable_or_groups = list(params_or_named_params)
+        if len(iterable_or_groups) == 0:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # The first value of the returned tuple.
+        param_groups = []
+
+        # The second value of the returned tuple,
+        # a List[List[str]], each sub-List is for a group.
+        param_groups_names = []
+
+        if not isinstance(iterable_or_groups[0], dict):
+            # case 1 or case 3,
+            # the input is an iterable of parameters or named parameters.
+            param_iterable_cur_group = []
+            param_names_cur_group = []
+            for p_or_np in iterable_or_groups:
+                if isinstance(p_or_np, tuple):
+                    # case 3
+                    name, param = p_or_np
+                else:
+                    # case 1
+                    assert isinstance(p_or_np, torch.Tensor)
+                    param = p_or_np
+                    # Assign a dummy name as a placeholder
+                    name = "foo"
+                    self.show_dominant_parameters = False
+                param_iterable_cur_group.append(param)
+                param_names_cur_group.append(name)
+            param_groups.append({"params": param_iterable_cur_group})
+            param_groups_names.append(param_names_cur_group)
+        else:
+            # case 2 or case 4,
+            # the input is groups of parameters or named parameters.
+            for p_or_np_cur_group in iterable_or_groups:
+                param_iterable_cur_group = []
+                param_names_cur_group = []
+                p_or_np_iterable = p_or_np_cur_group["params"]
+                for p_or_np in p_or_np_iterable:
+                    if isinstance(p_or_np, tuple):
+                        # case 4
+                        name, param = p_or_np
+                    else:
+                        # case 2
+                        assert isinstance(p_or_np, torch.Tensor)
+                        param = p_or_np
+                        # Assign a dummy name as a placeholder
+                        name = "foo"
+                        self.show_dominant_parameters = False
+                    param_iterable_cur_group.append(param)
+                    param_names_cur_group.append(name)
+
+                # The original `params` field contains named_parameters.
+                # After the following assignment,
+                # it will be changed to an iterable of parameters,
+                # and the other fields, if present, keep their original values.
+                # So param_groups can be used to initialize
+                # an underlying torch.optim.Optimizer later.
+                p_or_np_cur_group["params"] = param_iterable_cur_group
+                param_groups.append(p_or_np_cur_group)
+                param_groups_names.append(param_names_cur_group)
+        return param_groups, param_groups_names
 
     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
@@ -398,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
         self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
     ):
         """
-        Show information of parameter wihch dominanting tot_sumsq.
+        Show information of the parameter which dominates tot_sumsq.
 
         Args:
             tuples: a list of tuples of (param, state, param_names)
@@ -416,7 +537,7 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 batch_sumsq_orig = batch_grad**2
-                # Dummpy values used by following `zip` statement.
+                # Dummy values used by the following `zip` statement.
                 batch_rms_orig = torch.ones(p.shape[0])
             else:
                 batch_rms_orig = state["param_rms"]
@@ -449,11 +570,11 @@ class ScaledAdam(BatchedOptimizer):
             dominant_grad,
         ) = sorted_by_proportion[dominant_param_name]
         logging.info(
-            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
+            f"Parameter dominating tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
-            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+            f" grad_sumsq={(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
         )
@@ -561,11 +682,8 @@ class ScaledAdam(BatchedOptimizer):
         # when the param gets too small, just don't shrink it any further.
         scale_step.masked_fill_(is_too_small, 0.0)
-
-        # and ensure the parameter rms after update never exceeds param_max_rms.
-        scale_step = torch.minimum(scale_step,
-                                   (param_max_rms - param_rms) / param_rms)
-
+        # when it gets too large, stop it from getting any larger.
+        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
delta.add_(p * scale_step, alpha=(1 - beta1)) @@ -779,9 +897,7 @@ class Eden(LRScheduler): def _test_eden(): m = torch.nn.Linear(100, 100) - parameters_names = [ [ x[0] for x in m.named_parameters() ] ] - optim = ScaledAdam(m.parameters(), lr=0.03, - parameters_names=parameters_names) + optim = ScaledAdam(m.parameters(), lr=0.03) scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) @@ -992,9 +1108,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - parameters_names = [ [ x[0] for x in m.named_parameters() ] ] - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0, - parameters_names=parameters_names) + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 72cec1328..6c8ff1813 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -59,8 +59,6 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from zipformer import Zipformer -from scaling import ScheduledFloat from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -72,6 +70,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -80,125 +79,81 @@ from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) -from icefall.hooks import register_inf_check_hooks from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -def get_adjusted_batch_count( - params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). 
- return (params.batch_idx_train * params.ref_duration / - (params.max_duration * params.world_size)) - - - -def set_batch_count( - model: Union[nn.Module, DDP], batch_count: float -) -> None: +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): # get underlying nn.Module model = model.module - for name, module in model.named_modules(): - if hasattr(module, 'batch_count'): + for module in model.modules(): + if hasattr(module, "batch_count"): module.batch_count = batch_count - if hasattr(module, 'name'): - module.name = name def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="4,4,4,4,4,4", - help="Number of zipformer encoder layers per stack, comma separated.", + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", ) + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) parser.add_argument( - "--downsampling-factor", + "--nhead", type=str, - default="1,2,4,8,4,2", + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", help="Downsampling factor for each stack of encoder layers.", ) - parser.add_argument( - "--feedforward-dim", + "--cnn-module-kernels", type=str, - default="1792,1792,2304,2304,2304,1792", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="8,8,8,16,8,8", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--attention-share-layers", - type=str, - default="2", - help="Number of layers that share attention weights within each zipformer stack: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="384", - help="Embedding dimension in encoder stacks: a single int or comma-separated list." - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list." - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list." - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list." 
- ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension" - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim." - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", ) parser.add_argument( @@ -289,10 +244,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( @@ -311,21 +263,11 @@ def get_parser(): """, ) - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model" - ) - - parser.add_argument( "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -348,8 +290,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -386,7 +327,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=4000, + default=2000, help="""Save checkpoint after processing this number of batches" periodically. We save checkpoint to exp-dir/ whenever params.batch_idx_train % save_every_n == 0. 
The checkpoint filename @@ -501,24 +442,21 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Zipformer and Transformer def to_int_tuple(s: str): - return tuple(map(int, s.split(','))) + return tuple(map(int, s.split(","))) + encoder = Zipformer( num_features=params.feature_dim, output_downsampling_factor=2, - downsampling_factor=to_int_tuple(params.downsampling_factor), + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), num_encoder_layers=to_int_tuple(params.num_encoder_layers), - encoder_dim=to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=to_int_tuple(params.query_head_dim), - pos_head_dim=to_int_tuple(params.pos_head_dim), - value_head_dim=to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=to_int_tuple(params.num_heads), - attention_share_layers=to_int_tuple(params.attention_share_layers), - feedforward_dim=to_int_tuple(params.feedforward_dim), - cnn_module_kernel=to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, ) return encoder @@ -535,7 +473,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=int(params.encoder_dim.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -552,7 +490,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(params.encoder_dim.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -687,7 +625,7 @@ def compute_loss( is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute transducer loss given the model and its inputs. Args: params: @@ -704,11 +642,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -738,27 +672,24 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. 
simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss = ( - simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss - ) + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -853,22 +784,7 @@ def train_one_epoch( cur_batch_idx = params.get("cur_batch_idx", 0) - saved_bad_model = False - def save_bad_model(suffix: str = ""): - save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) if batch_idx < cur_batch_idx: continue cur_batch_idx = batch_idx @@ -891,13 +807,13 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) scheduler.step_batch(params.batch_idx_train) scaler.step(optimizer) scaler.update() optimizer.zero_grad() except: # noqa - save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise @@ -944,17 +860,14 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] @@ -964,8 +877,8 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " + - (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -976,16 +889,14 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( @@ -997,7 +908,9 @@ def train_one_epoch( ) model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -1024,6 +937,8 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 fix_random_seed(params.seed) if world_size > 1: @@ -1071,14 +986,12 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], - find_unused_parameters=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = ScaledAdam( - model.parameters(), + model.named_parameters(), lr=params.base_lr, clipping_scale=2.0, - parameters_names=[ [p[0] for p in model.named_parameters()] ], ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1097,7 +1010,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1120,7 +1033,33 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) @@ -1148,8 +1087,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, - init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1270,7 +1208,9 @@ def scan_pessimistic_batches_for_oom( ) display_and_save_batch(batch, params=params, sp=sp) raise - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) def main():
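
A quick worked example of the T >= S cut filter that this patch adds in remove_short_and_long_utt (a standalone sketch for illustration only; the helper name keep_cut and the numbers are made up, not part of either file):

def keep_cut(duration: float, num_frames: int, num_tokens: int) -> bool:
    # Same duration bounds as remove_short_and_long_utt in train.py.
    if duration < 1.0 or duration > 20.0:
        return False
    # Frames left after the conv front end in zipformer.py: ((n - 7) // 2 + 1) // 2
    T = ((num_frames - 7) // 2 + 1) // 2
    # Pruned RNN-T requires at least as many subsampled frames as tokens (T >= S).
    return T >= num_tokens

# e.g. a 10 s cut with 1000 feature frames: T = ((1000 - 7) // 2 + 1) // 2 = 248,
# so it is kept as long as its transcript has at most 248 BPE tokens.
print(keep_cut(10.0, 1000, 35))   # True
print(keep_cut(10.0, 1000, 300))  # False: more tokens than subsampled frames
print(keep_cut(0.5, 50, 5))       # False: shorter than 1 second

The 1.0-20.0 s duration bounds and the ((n - 7) // 2 + 1) // 2 subsampling expression mirror the values used in train.py above; everything else in the sketch is hypothetical.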