Fix bug regarding --start-batch option

commit cdd9cf695f
parent cbd59b9c68
Author: Daniel Povey
Date:   2023-05-29 16:41:54 +08:00
2 changed files with 5 additions and 2 deletions

@@ -709,6 +709,7 @@ def train_one_epoch(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
+    batch_idx_offset: int = 0,
 ) -> None:
     """Train the model for one epoch.
@ -759,7 +760,8 @@ def train_one_epoch(
rank=0)
for batch_idx, batch in enumerate(train_dl):
for batch_idx_, batch in enumerate(train_dl):
batch_idx = batch_idx_ + batch_idx_offset
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
@@ -1038,6 +1040,7 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
+            batch_idx_offset=(getattr(params, 'cur_batch_idx', 0) if epoch == params.start_epoch else 0),
         )

         if params.print_diagnostics:
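Why this first change is needed: enumerate(train_dl) always restarts counting at 0, so when training was resumed mid-epoch via --start-batch, batch_idx (and everything keyed on it, such as the periodic set_batch_count calls) disagreed with the true position in the epoch. The new batch_idx_offset shifts the local index back to the global one, taking params.cur_batch_idx from the restored checkpoint for the resumed epoch only. A minimal runnable sketch of the idea (the loop body and data are illustrative, not the icefall code):

    def train_one_epoch(train_dl, batch_idx_offset: int = 0) -> None:
        for batch_idx_, batch in enumerate(train_dl):
            # enumerate() restarts at 0 on every call; adding the offset
            # restores the position within the epoch after a mid-epoch resume.
            batch_idx = batch_idx_ + batch_idx_offset
            if batch_idx % 10 == 0:
                print(f"batch {batch_idx}: {batch!r}")

    # Resuming as if 20 batches of this epoch had already been consumed:
    train_one_epoch(train_dl=["a", "b", "c"], batch_idx_offset=20)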

@@ -896,7 +896,7 @@ class AbsValuePenalizer(nn.Module):
         self.mem_cutoff = CutoffEstimator(0.2)

     def forward(self, x: Tensor) -> Tensor:
-        if (torch.jit.is_scripting() or not x.requires_grad or
+        if (torch.jit.is_scripting() or not x.requires_grad
             or not self.training
             or random.random() > self.prob):
             # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
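The second change fixes a dangling `or`: an earlier edit turned the mem_cutoff clause into the comment still visible in the hunk, but left the `or` that preceded it at the end of the first line, so the condition parsed as `... not x.requires_grad or` followed by `or not self.training`, which is a SyntaxError. A self-contained illustration of the pitfall (the names a, b, d are hypothetical stand-ins for the real clauses):

    # Before the fix the condition had the shape
    #     if (a or b or
    #         # or c
    #         or d):
    # which Python reads as `b or or d` -- a SyntaxError. Dropping the
    # trailing `or` leaves each continuation line to supply its own `or`:
    a, b, d = False, False, True
    if (a or b
            # or c   (clause kept as a comment, as in the diff)
            or d):
        print("penalty skipped")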