diff --git a/egs/librispeech/ASR/zipformer/train-limit-grad.py b/egs/librispeech/ASR/zipformer/train-limit-grad.py
index cd74d78c0..0a8ee894c 100755
--- a/egs/librispeech/ASR/zipformer/train-limit-grad.py
+++ b/egs/librispeech/ASR/zipformer/train-limit-grad.py
@@ -1121,11 +1121,18 @@ def train_one_epoch(
             rank=0,
         )
 
+    def is_grad_limit_enabled():
+        return (0 < params.limit_grad_start_batch <= params.batch_idx_train) and (
+            params.batch_idx_train % params.limit_grad_every_n_batch == 0
+        )
+
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
 
         params.batch_idx_train += 1
+
+        beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
         batch_size = len(batch["supervisions"]["text"])
 
         try:
@@ -1135,9 +1142,7 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    model_prev=model_prev
-                    if 0 < params.limit_grad_start_batch < params.batch_idx_train
-                    else None,
+                    model_prev=model_prev if is_grad_limit_enabled() else None,
                     sp=sp,
                     batch=batch,
                     is_training=True,
@@ -1155,14 +1160,10 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
 
-            if (
-                0 < params.limit_grad_start_batch <= params.batch_idx_train
-                and params.batch_idx_train % params.limit_grad_every_n_batch == 0
-            ):
+            if is_grad_limit_enabled():
                 if model_prev is None:
                     model_prev = copy.deepcopy(model)
                 else:
-                    beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
                     update_model_prev(model_prev=model_prev, model=model, beta=beta)
 
         except Exception as e:
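
For context: `update_model_prev` itself is not part of this hunk, so its body is not shown here. Given the `beta`-weighted call site (`beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))`, which ramps toward 1.0 as training proceeds), a minimal sketch of an EMA-style parameter update it could correspond to is below; the signature and the exact averaging rule are assumptions, not taken from the patch.

```python
import torch


def update_model_prev(
    model_prev: torch.nn.Module, model: torch.nn.Module, beta: float
) -> None:
    """Hypothetical EMA update of the reference model:
    prev <- beta * prev + (1 - beta) * current.
    A larger beta makes model_prev track the live model more slowly."""
    with torch.no_grad():
        for p_prev, p_cur in zip(model_prev.parameters(), model.parameters()):
            p_prev.mul_(beta).add_(p_cur, alpha=1.0 - beta)
```

Hoisting `beta` out of the `else` branch, as the patch does, means it is computed once per batch and is in scope whenever `is_grad_limit_enabled()` triggers the update, without changing the value passed to `update_model_prev`.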