From 97df1ce3ebe65e2eaf47b4c576c71d17fda714c0 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 30 Oct 2024 19:21:46 +0800
Subject: [PATCH] Save rng states.

---
 egs/librispeech/ASR/zipformer/model.py        | 87 +++++++++++++++++++
 .../ASR/zipformer/train-limit-grad.py         | 46 +++++++++-
 2 files changed, 130 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index c7dbe1e0a..b8f7d6336 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -16,6 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 from typing import Optional, Tuple
 
 import k2
@@ -159,6 +160,9 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
+        encoder_out_prev: Optional[torch.Tensor] = None,
+        encoder_out_lens_prev: Optional[torch.Tensor] = None,
+        model_prev=None,
     ) -> torch.Tensor:
         """Compute CTC loss.
         Args:
@@ -170,8 +174,43 @@ class AsrModel(nn.Module):
             Target Tensor of shape (sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
         """
+        device = encoder_out.device
+        if model_prev:
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
+        print(
+            "ctc_output",
+            ctc_output.detach().mean(),
+            ctc_output.detach().sum(),
+            ctc_output.detach().min(),
+            ctc_output.detach().max(),
+        )
+
+        if model_prev:
+            with torch.random.fork_rng(devices=[device]):
+                torch.set_rng_state(cpu_state)
+                torch.cuda.set_rng_state(cuda_state, device)
+
+                rng_state2 = random.getstate()
+                random.setstate(rng_state)
+
+                ctc_output_prev = model_prev.ctc_output(encoder_out)
+                random.setstate(rng_state2)
+                print(
+                    "ctc_output_prev",
+                    ctc_output_prev.detach().mean(),
+                    ctc_output_prev.detach().sum(),
+                    ctc_output_prev.detach().min(),
+                    ctc_output_prev.detach().max(),
+                )
+                print(
+                    "isclose ctc",
+                    (ctc_output - ctc_output_prev).detach().abs().max(),
+                )
 
         ctc_loss = torch.nn.functional.ctc_loss(
             log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
@@ -345,6 +384,7 @@ class AsrModel(nn.Module):
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
         time_warp_factor: Optional[int] = 80,
+        model_prev=None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -418,9 +458,53 @@ class AsrModel(nn.Module):
             x_lens = x_lens.repeat(2)
             y = k2.ragged.cat([y, y], axis=0)
 
+        device = x.device
+        if model_prev:
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
+        print(
+            "encoder_out",
+            encoder_out.detach().mean(),
+            encoder_out.detach().abs().max(),
+            encoder_out.detach().abs().min(),
+            encoder_out.detach().sum(),
+            encoder_out.shape,
+        )
+
+        if model_prev:
+            with torch.random.fork_rng(devices=[device]):
+                torch.set_rng_state(cpu_state)
+                torch.cuda.set_rng_state(cuda_state, device)
+
+                rng_state2 = random.getstate()
+                random.setstate(rng_state)
+
+                encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
+                    x, x_lens
+                )
+                random.setstate(rng_state2)
+                print(
+                    "encoder_out_prev",
+                    encoder_out_prev.detach().mean(),
+                    encoder_out_prev.detach().abs().max(),
+                    encoder_out_prev.detach().abs().mean(),
+                    encoder_out_prev.detach().sum(),
+                    encoder_out_prev.shape,
+                )
+                print(
+                    "isclose",
+                    (encoder_out - encoder_out_prev).detach().abs().max(),
+                    (encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
+                )
+        else:
+            encoder_out_prev = None
+            encoder_out_lens_prev = None
+
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
 
@@ -451,6 +535,9 @@ class AsrModel(nn.Module):
                 encoder_out_lens=encoder_out_lens,
                 targets=targets,
                 target_lengths=y_lens,
+                encoder_out_prev=encoder_out_prev,
+                encoder_out_lens_prev=encoder_out_lens_prev,
+                model_prev=model_prev,
             )
             cr_loss = torch.empty(0)
         else:
diff --git a/egs/librispeech/ASR/zipformer/train-limit-grad.py b/egs/librispeech/ASR/zipformer/train-limit-grad.py
index c074c32ec..964adeede 100755
--- a/egs/librispeech/ASR/zipformer/train-limit-grad.py
+++ b/egs/librispeech/ASR/zipformer/train-limit-grad.py
@@ -549,6 +549,14 @@ def get_parser():
         help="Whether to use bf16 in AMP.",
     )
 
+    parser.add_argument(
+        "--limit-grad-start-batch",
+        type=int,
+        # default=1000,
+        default=2,
+        help="Limit grad starting from this batch.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -879,6 +887,7 @@ def compute_loss(
     batch: dict,
     is_training: bool,
     spec_augment: Optional[SpecAugment] = None,
+    model_prev: Union[nn.Module, DDP] = None,
) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -942,6 +951,7 @@ def compute_loss(
             spec_augment=spec_augment,
             supervision_segments=supervision_segments,
             time_warp_factor=params.spec_aug_time_warp_factor,
+            model_prev=model_prev,
         )
 
     loss = 0.0
@@ -1037,6 +1047,7 @@ def train_one_epoch(
     scaler: GradScaler,
     spec_augment: Optional[SpecAugment] = None,
     model_avg: Optional[nn.Module] = None,
+    model_prev: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -1104,9 +1115,14 @@ def train_one_epoch(
             with torch.cuda.amp.autocast(
                 enabled=params.use_autocast, dtype=params.dtype
             ):
+                if params.batch_idx_train > params.limit_grad_start_batch:
+                    model_prev = copy.deepcopy(model)
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
+                    model_prev=model_prev
+                    if params.batch_idx_train > params.limit_grad_start_batch
+                    else None,
                     sp=sp,
                     batch=batch,
                     is_training=True,
@@ -1123,6 +1139,19 @@ def train_one_epoch(
                 scaler.step(optimizer)
                 scaler.update()
                 optimizer.zero_grad()
+
+                if params.batch_idx_train >= params.limit_grad_start_batch:
+                    if model_prev is None:
+                        model_prev = copy.deepcopy(model)
+                    else:
+                        model_prev = copy.deepcopy(model)
+                        print(
+                            "here",
+                            params.batch_idx_train,
+                            params.limit_grad_start_batch,
+                            model_prev is None,
+                        )
+
             except Exception as e:
                 logging.info(f"Caught exception: {e}.")
                 save_bad_model()
@@ -1208,7 +1237,7 @@ def train_one_epoch(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 1000 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -1233,6 +1262,8 @@ def train_one_epoch(
             params.best_train_epoch = params.cur_epoch
             params.best_train_loss = params.train_loss
 
+    return model_prev
+
 
 def run(rank, world_size, args):
     """
@@ -1319,6 +1350,9 @@ def run(rank, world_size, args):
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model).to(torch.float64)
 
+    model_prev: Optional[nn.Module] = None
+    # TODO(fangjun): load checkpoint for model_prev
+
     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
@@ -1428,7 +1462,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1453,10 +1487,11 @@ def run(rank, world_size, args):
 
         params.cur_epoch = epoch
 
-        train_one_epoch(
+        model_prev = train_one_epoch(
             params=params,
             model=model,
             model_avg=model_avg,
+            model_prev=model_prev,
             optimizer=optimizer,
             scheduler=scheduler,
             sp=sp,
@@ -1587,4 +1622,9 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 if __name__ == "__main__":
+    # torch.use_deterministic_algorithms(True, warn_only=True)
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
+    # torch.backends.cudnn.enabled = False
+
     main()