mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
skip scan_pessimistic_batches_for_oom if params.start_batch > 0
This commit is contained in:
parent
efdbb98583
commit
728eb5075f
@ -873,13 +873,14 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts += librispeech.dev_other_cuts()
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
scan_pessimistic_batches_for_oom(
|
if params.start_batch <= 0:
|
||||||
model=model,
|
scan_pessimistic_batches_for_oom(
|
||||||
train_dl=train_dl,
|
model=model,
|
||||||
optimizer=optimizer,
|
train_dl=train_dl,
|
||||||
sp=sp,
|
optimizer=optimizer,
|
||||||
params=params,
|
sp=sp,
|
||||||
)
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
fix_random_seed(params.seed + epoch)
|
fix_random_seed(params.seed + epoch)
|
||||||
|
@ -924,7 +924,7 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts += librispeech.dev_other_cuts()
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
|
@ -1040,14 +1040,15 @@ def run(rank, world_size, args):
|
|||||||
# It's time consuming to include `giga_train_dl` here
|
# It's time consuming to include `giga_train_dl` here
|
||||||
# for dl in [train_dl, giga_train_dl]:
|
# for dl in [train_dl, giga_train_dl]:
|
||||||
for dl in [train_dl]:
|
for dl in [train_dl]:
|
||||||
scan_pessimistic_batches_for_oom(
|
if params.start_batch <= 0:
|
||||||
model=model,
|
scan_pessimistic_batches_for_oom(
|
||||||
train_dl=dl,
|
model=model,
|
||||||
optimizer=optimizer,
|
train_dl=dl,
|
||||||
sp=sp,
|
optimizer=optimizer,
|
||||||
params=params,
|
sp=sp,
|
||||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
params=params,
|
||||||
)
|
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||||
|
)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16)
|
scaler = GradScaler(enabled=params.use_fp16)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
|
@ -973,7 +973,7 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts += librispeech.dev_other_cuts()
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
|
@ -962,7 +962,7 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts += librispeech.dev_other_cuts()
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
|
@ -507,9 +507,6 @@ def load_checkpoint_if_available(
|
|||||||
if "cur_epoch" in saved_params:
|
if "cur_epoch" in saved_params:
|
||||||
params["start_epoch"] = saved_params["cur_epoch"]
|
params["start_epoch"] = saved_params["cur_epoch"]
|
||||||
|
|
||||||
if "cur_batch_idx" in saved_params:
|
|
||||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
|
||||||
|
|
||||||
return saved_params
|
return saved_params
|
||||||
|
|
||||||
|
|
||||||
@ -754,13 +751,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
if batch_idx < cur_batch_idx:
|
|
||||||
continue
|
|
||||||
cur_batch_idx = batch_idx
|
|
||||||
|
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
@ -802,7 +793,6 @@ def train_one_epoch(
|
|||||||
params.batch_idx_train > 0
|
params.batch_idx_train > 0
|
||||||
and params.batch_idx_train % params.save_every_n == 0
|
and params.batch_idx_train % params.save_every_n == 0
|
||||||
):
|
):
|
||||||
params.cur_batch_idx = batch_idx
|
|
||||||
save_checkpoint_with_global_batch_idx(
|
save_checkpoint_with_global_batch_idx(
|
||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir,
|
||||||
global_batch_idx=params.batch_idx_train,
|
global_batch_idx=params.batch_idx_train,
|
||||||
@ -815,7 +805,6 @@ def train_one_epoch(
|
|||||||
scaler=scaler,
|
scaler=scaler,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
del params.cur_batch_idx
|
|
||||||
remove_checkpoints(
|
remove_checkpoints(
|
||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir,
|
||||||
topk=params.keep_last_k,
|
topk=params.keep_last_k,
|
||||||
@ -990,7 +979,7 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts += librispeech.dev_other_cuts()
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user