From 256c446f0620658f7f5ab7124c2e854bed5e933d Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 30 Oct 2024 21:11:07 +0800
Subject: [PATCH] First working version

---
 egs/librispeech/ASR/zipformer/model.py       | 48 +++----------------
 egs/librispeech/ASR/zipformer/scaling.py     | 16 +++++--
 .../ASR/zipformer/train-limit-grad.py        | 47 ++++++++++--------
 3 files changed, 46 insertions(+), 65 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index ab216a2ae..baeffad88 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -25,7 +25,7 @@
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from lhotse.dataset import SpecAugment
-from scaling import ScaledLinear
+from scaling import ScaledLinear, scale_grad
 
 from icefall.utils import add_sos, make_pad_mask, time_warp
@@ -198,13 +198,6 @@ class AsrModel(nn.Module):
 
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
-        print(
-            "ctc_output",
-            ctc_output.detach().mean(),
-            ctc_output.detach().sum(),
-            ctc_output.detach().min(),
-            ctc_output.detach().max(),
-        )
 
         if model_prev:
             with fork_rng(
@@ -213,18 +206,11 @@ class AsrModel(nn.Module):
                 rng_state=rng_state,
                 device=device,
             ):
-                ctc_output_prev = model_prev.ctc_output(encoder_out)
-                print(
-                    "ctc_output_prev",
-                    ctc_output_prev.detach().mean(),
-                    ctc_output_prev.detach().sum(),
-                    ctc_output_prev.detach().min(),
-                    ctc_output_prev.detach().max(),
-                )
-                print(
-                    "isclose ctc",
-                    (ctc_output - ctc_output).detach().abs().max(),
-                )
+                ctc_output_prev = model_prev.ctc_output(encoder_out_prev)
+
+            has_grown = ctc_output > 0.8 * ctc_output_prev
+            grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
+            ctc_output = scale_grad(ctc_output, grad_scale_tensor)
 
         ctc_loss = torch.nn.functional.ctc_loss(
             log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
@@ -481,15 +467,6 @@ class AsrModel(nn.Module):
 
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
-        print(
-            "encoder_out",
-            encoder_out.detach().mean(),
-            encoder_out.detach().abs().max(),
-            encoder_out.detach().abs().min(),
-            encoder_out.detach().sum(),
-            encoder_out.shape,
-        )
-
         if model_prev:
             with fork_rng(
                 cpu_state=cpu_state,
@@ -500,19 +477,6 @@ class AsrModel(nn.Module):
                 encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
                     x, x_lens
                 )
-                print(
-                    "encoder_out_prev",
-                    encoder_out_prev.detach().mean(),
-                    encoder_out_prev.detach().abs().max(),
-                    encoder_out_prev.detach().abs().mean(),
-                    encoder_out_prev.detach().sum(),
-                    encoder_out_prev.shape,
-                )
-                print(
-                    "isclose",
-                    (encoder_out - encoder_out_prev).detach().abs().max(),
-                    (encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
-                )
         else:
             encoder_out_prev = None
             encoder_out_lens_prev = None
diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index d345c2931..cf617ba51 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -1136,16 +1136,24 @@ def with_loss(x, y, name):
 
 class ScaleGradFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
-        ctx.alpha = alpha
+    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
+        if isinstance(alpha, Tensor):
+            ctx.save_for_backward(alpha)
+        else:
+            ctx.alpha = alpha
         return x
 
     @staticmethod
     def backward(ctx, grad: Tensor):
-        return grad * ctx.alpha, None
+        if hasattr(ctx, "alpha"):
+            alpha = ctx.alpha
+        else:
+            (alpha,) = ctx.saved_tensors
+
+        return grad * alpha, None
 
 
-def scale_grad(x: Tensor, alpha: float):
+def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
     return ScaleGradFunction.apply(x, alpha)
 
 
diff --git a/egs/librispeech/ASR/zipformer/train-limit-grad.py b/egs/librispeech/ASR/zipformer/train-limit-grad.py
index 964adeede..cd74d78c0 100755
--- a/egs/librispeech/ASR/zipformer/train-limit-grad.py
+++ b/egs/librispeech/ASR/zipformer/train-limit-grad.py
@@ -552,9 +552,15 @@ def get_parser():
     parser.add_argument(
         "--limit-grad-start-batch",
         type=int,
-        # default=1000,
-        default=2,
-        help="Limit grad starting from this batch.",
+        default=1000,
+        help="Enable the grad limit starting from this batch. Set it to 0 to disable it.",
+    )
+
+    parser.add_argument(
+        "--limit-grad-every-n-batch",
+        type=int,
+        default=1,
+        help="Apply the grad limit every this many batches when it is enabled.",
     )
 
     add_model_arguments(parser)
@@ -1036,6 +1042,17 @@ def compute_validation_loss(
     return tot_loss
 
 
+@torch.inference_mode()
+def update_model_prev(model_prev, model, beta):
+    # model_prev = beta * model_prev + (1-beta) * model
+    model_prev_dict = model_prev.state_dict()
+    model_dict = model.state_dict()
+    for key in model_prev_dict:
+        model_prev_dict[key].data.copy_(
+            model_prev_dict[key].data * beta + model_dict[key].data * (1 - beta)
+        )
+
+
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -1115,13 +1132,11 @@ def train_one_epoch(
             with torch.cuda.amp.autocast(
                 enabled=params.use_autocast, dtype=params.dtype
             ):
-                if params.batch_idx_train > params.limit_grad_start_batch:
-                    model_prev = copy.deepcopy(model)
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
                     model_prev=model_prev
-                    if params.batch_idx_train > params.limit_grad_start_batch
+                    if 0 < params.limit_grad_start_batch < params.batch_idx_train
                     else None,
                     sp=sp,
                     batch=batch,
@@ -1140,17 +1155,15 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
 
-            if params.batch_idx_train >= params.limit_grad_start_batch:
+            if (
+                0 < params.limit_grad_start_batch <= params.batch_idx_train
+                and params.batch_idx_train % params.limit_grad_every_n_batch == 0
+            ):
                 if model_prev is None:
                     model_prev = copy.deepcopy(model)
                 else:
-                    model_prev = copy.deepcopy(model)
-                    print(
-                        "here",
-                        params.batch_idx_train,
-                        params.limit_grad_start_batch,
-                        model_prev is None,
-                    )
+                    beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
+                    update_model_prev(model_prev=model_prev, model=model, beta=beta)
 
         except Exception as e:
             logging.info(f"Caught exception: {e}.")
@@ -1221,6 +1234,7 @@ def train_one_epoch(
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
                 + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
+                + (f", beta: {beta}" if model_prev is not None else "")
             )
 
             if tb_writer is not None:
@@ -1622,9 +1636,4 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 if __name__ == "__main__":
-    # torch.use_deterministic_algorithms(True, warn_only=True)
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
-    # torch.backends.cudnn.enabled = False
-
     main()
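
Note (not part of the patched files): the following is a minimal, self-contained sketch of the gradient-limiting idea used above. The ScaleGradFunction/scale_grad definitions mirror the tensor-valued version added to scaling.py; the toy tensors x and x_prev in the __main__ block are made up for the demo and only stand in for the current and previous model's ctc_output.

from typing import Union

import torch
from torch import Tensor


class ScaleGradFunction(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient
    element-wise by `alpha` in the backward pass."""

    @staticmethod
    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
        if isinstance(alpha, Tensor):
            ctx.save_for_backward(alpha)
        else:
            ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        if hasattr(ctx, "alpha"):
            alpha = ctx.alpha
        else:
            (alpha,) = ctx.saved_tensors
        # No gradient is returned for alpha itself, hence the trailing None.
        return grad * alpha, None


def scale_grad(x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
    return ScaleGradFunction.apply(x, alpha)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 3, requires_grad=True)  # stands in for ctc_output
    x_prev = x.detach() - 0.1                  # stands in for ctc_output_prev

    # Halve the gradient wherever the current output has grown past
    # 0.8 * the previous output, mirroring the forward_ctc change above.
    has_grown = x > 0.8 * x_prev
    alpha = torch.where(has_grown, 0.5, 1.0)

    y = scale_grad(x, alpha)
    y.sum().backward()
    print(x.grad)  # 0.5 where has_grown is True, 1.0 elsewhere

On the update_model_prev side, beta = max(0.5, 1.0 - 1.0 / (0.1 * batch_idx_train)) is clipped below at 0.5 and approaches 1.0 as training progresses (0.99 at batch 1000, 0.999 at batch 10000), so model_prev is moved toward the live model by a small and shrinking fraction each time the limit is applied.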