mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix bug regarding --start-batch option
This commit is contained in:
parent
cbd59b9c68
commit
cdd9cf695f
@ -709,6 +709,7 @@ def train_one_epoch(
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
batch_idx_offset: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
@ -759,7 +760,8 @@ def train_one_epoch(
|
||||
rank=0)
|
||||
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
for batch_idx_, batch in enumerate(train_dl):
|
||||
batch_idx = batch_idx_ + batch_idx_offset
|
||||
if batch_idx % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
|
||||
@ -1038,6 +1040,7 @@ def run(rank, world_size, args):
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
batch_idx_offset=(getattr(params, 'cur_batch_idx', 0) if epoch == params.start_epoch else 0),
|
||||
)
|
||||
|
||||
if params.print_diagnostics:
|
||||
|
||||
@ -896,7 +896,7 @@ class AbsValuePenalizer(nn.Module):
|
||||
self.mem_cutoff = CutoffEstimator(0.2)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if (torch.jit.is_scripting() or not x.requires_grad or
|
||||
if (torch.jit.is_scripting() or not x.requires_grad
|
||||
or not self.training
|
||||
or random.random() > self.prob):
|
||||
# or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user