diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index fddf6e59d..f3f70481a 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -709,6 +709,7 @@ def train_one_epoch(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
+    batch_idx_offset: int = 0,
 ) -> None:
     """Train the model for one epoch.
@@ -759,7 +760,8 @@
         rank=0)
 
-    for batch_idx, batch in enumerate(train_dl):
+    for batch_idx_, batch in enumerate(train_dl):
+        batch_idx = batch_idx_ + batch_idx_offset
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
@@ -1038,6 +1040,7 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
+            batch_idx_offset=(getattr(params, 'cur_batch_idx', 0) if epoch == params.start_epoch else 0),
         )
 
         if params.print_diagnostics:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 44ca6e0a8..9839a3fb9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -896,7 +896,7 @@ class AbsValuePenalizer(nn.Module):
         self.mem_cutoff = CutoffEstimator(0.2)
 
     def forward(self, x: Tensor) -> Tensor:
-        if (torch.jit.is_scripting() or not x.requires_grad or
+        if (torch.jit.is_scripting() or not x.requires_grad or not self.training or
             random.random() > self.prob):
             # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
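
Note: a minimal sketch (hypothetical helper, not the icefall code) of the resume behavior the train.py change enables. `batch_idx_offset` shifts the loop's local index so that, when training resumes from a mid-epoch checkpoint whose `cur_batch_idx` was saved into `params`, `batch_idx` continues from where the previous run stopped. The helper name `toy_epoch` and the `range(...)` stand-in for `train_dl` are illustrative assumptions.

def toy_epoch(train_dl, batch_idx_offset=0):
    seen = []
    for batch_idx_, batch in enumerate(train_dl):
        batch_idx = batch_idx_ + batch_idx_offset  # global position within the epoch
        seen.append(batch_idx)
    return seen

# Fresh epoch: indices start at 0.
assert toy_epoch(range(3)) == [0, 1, 2]

# Resumed epoch: if the checkpoint recorded cur_batch_idx == 100, the offset
# keeps anything keyed on batch_idx (e.g. the `batch_idx % 10 == 0`
# set_batch_count call in the diff) aligned with the interrupted run.
assert toy_epoch(range(3), batch_idx_offset=100) == [100, 101, 102]

The `if epoch == params.start_epoch else 0` part of the call site matters: the offset applies only to the resumed epoch, while every later epoch starts counting from 0 again.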
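Note: the scaling.py change adds `not self.training` to the early-exit test, so AbsValuePenalizer becomes a guaranteed no-op in eval mode instead of still firing with probability `prob`. A stripped-down, runnable sketch of just that guard (class name `TinyPenalizer`, the `prob` default, and the `x * 1.0` stand-in for the real penalty branch are all assumptions for illustration):

import random

import torch
from torch import Tensor, nn


class TinyPenalizer(nn.Module):
    # Keeps only the early-exit guard from the diff; the real penalty
    # computation is elided.
    def __init__(self, prob: float = 0.25):
        super().__init__()
        self.prob = prob

    def forward(self, x: Tensor) -> Tensor:
        if (torch.jit.is_scripting() or not x.requires_grad or
                not self.training or random.random() > self.prob):
            return x  # no-op path: the diff's early return
        return x * 1.0  # stand-in for the penalty branch (returns a new tensor)


m = TinyPenalizer()
x = torch.randn(4, requires_grad=True)

m.train()
train_hits = sum(m(x) is not x for _ in range(1000))  # roughly prob * 1000

m.eval()
eval_hits = sum(m(x) is not x for _ in range(1000))
assert eval_hits == 0  # with the new guard, never fires at eval time

Before this change, an eval-mode forward pass on a tensor with `requires_grad=True` could still enter the penalty branch; the added `not self.training` term rules that out while leaving training behavior unchanged.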