Fix bug regarding --start-batch option

Daniel Povey 2023-05-29 16:41:54 +08:00
parent cbd59b9c68
commit cdd9cf695f
2 changed files with 5 additions and 2 deletions

@@ -709,6 +709,7 @@ def train_one_epoch(
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
+    batch_idx_offset: int = 0,
 ) -> None:
     """Train the model for one epoch.
@@ -759,7 +760,8 @@ def train_one_epoch(
             rank=0)
-    for batch_idx, batch in enumerate(train_dl):
+    for batch_idx_, batch in enumerate(train_dl):
+        batch_idx = batch_idx_ + batch_idx_offset
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
@@ -1038,6 +1040,7 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
+            batch_idx_offset=(getattr(params, 'cur_batch_idx', 0) if epoch == params.start_epoch else 0),
         )

         if params.print_diagnostics:
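For context, here is the effect of the new batch_idx_offset argument in isolation. This is a minimal, self-contained sketch, not the actual icefall train.py: the dummy dataloader and the SimpleNamespace stand-in for params are assumptions. It shows how a run resumed via --start-batch keeps counting batches from the saved cur_batch_idx instead of restarting at 0, and only for the first resumed epoch.

from types import SimpleNamespace


def train_one_epoch(train_dl, batch_idx_offset: int = 0) -> None:
    for batch_idx_, batch in enumerate(train_dl):
        # Shift the loop counter so a resumed run continues from the
        # checkpoint's batch index instead of restarting at 0.
        batch_idx = batch_idx_ + batch_idx_offset
        if batch_idx % 10 == 0:
            print(f"processing batch {batch_idx}")


# Pretend we resumed from a checkpoint saved at batch 20 of epoch 3
# (hypothetical numbers, only for illustration).
params = SimpleNamespace(start_epoch=3, cur_batch_idx=20)

for epoch in range(params.start_epoch, params.start_epoch + 2):
    # Only the epoch we resume into gets the offset; later epochs start
    # counting from 0 again, matching the call-site change above.
    offset = getattr(params, "cur_batch_idx", 0) if epoch == params.start_epoch else 0
    train_one_epoch(train_dl=["dummy_batch"] * 15, batch_idx_offset=offset)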

@@ -896,7 +896,7 @@ class AbsValuePenalizer(nn.Module):
         self.mem_cutoff = CutoffEstimator(0.2)

     def forward(self, x: Tensor) -> Tensor:
-        if (torch.jit.is_scripting() or not x.requires_grad or
+        if (torch.jit.is_scripting() or not x.requires_grad
                 or not self.training
                 or random.random() > self.prob):
             # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
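The scaling.py change is a one-line fix: the old condition ended with a dangling "or", which the new line drops. Below is a minimal sketch of the corrected early-exit gate. It is a simplified stand-in, not the real AbsValuePenalizer (the CutoffEstimator/mem_cutoff logic and the actual penalty are not reproduced), kept only to show that the corrected condition parses and short-circuits as intended.

import random

import torch
from torch import Tensor, nn


class AbsValuePenalizerSketch(nn.Module):
    """Simplified stand-in for AbsValuePenalizer, showing only the gating."""

    def __init__(self, prob: float = 0.1):
        super().__init__()
        self.prob = prob

    def forward(self, x: Tensor) -> Tensor:
        # Corrected condition: no trailing "or" before the continuation lines.
        if (torch.jit.is_scripting() or not x.requires_grad
                or not self.training
                or random.random() > self.prob):
            # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
            return x  # skip: scripting, no grad, eval mode, or sampled out
        # The real module would apply its abs-value penalty here.
        return x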