test restore dynamic sampler

luomingshuang 2022-04-22 14:42:09 +08:00
parent 18a1e959f7
commit 0930748b61


@@ -251,7 +251,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=200,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
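Lowering --save-every-n from 8000 to 200 makes the periodic checkpoint fire often enough to exercise the restore path while testing. A runnable toy of the modulo condition the help text describes (only the % test mirrors the recipe; the loop and filename are illustrative):

# Toy illustration of the save-every-n condition from the help text.
# Only the modulo test mirrors the recipe; the loop and filename are made up.
save_every_n = 200
for batch_idx_train in range(1, 1001):
    if batch_idx_train % save_every_n == 0:
        print(f"would save exp-dir/checkpoint-{batch_idx_train}.pt")
# fires at 200, 400, 600, 800, 1000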
@@ -333,7 +333,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 3000,
             # parameters for conformer
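Setting log_interval to 1 logs every batch, which is the kind of verbosity you want while checking that resumption picks up at the right batch index. A toy of how such an interval typically gates per-batch logging (the loop is illustrative; the interval values come from the dict above):

# Toy illustration of interval-gated logging; the loop is made up.
import logging

logging.basicConfig(level=logging.INFO)
log_interval = 1  # was 50; log every batch while debugging the restore

for batch_idx in range(5):
    if batch_idx % log_interval == 0:
        logging.info("batch %d: ...", batch_idx)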
@@ -682,9 +682,9 @@ def train_one_epoch(
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
-    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
+    for batch_idx, batch in enumerate(train_dl, cur_batch_idx):
+        # if batch_idx < cur_batch_idx:
+        #     continue
         cur_batch_idx = batch_idx
         params.batch_idx_train += 1
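The key change: enumerate(iterable, start) only offsets the index; it does not skip items. The old loop fast-forwarded by re-reading and discarding the first cur_batch_idx batches; with the sampler state restored, the dataloader already resumes mid-epoch, so only the numbering needs shifting. A runnable illustration:

# enumerate(iterable, start) offsets the index without skipping items.
cur_batch_idx = 3
batches = ["b3", "b4", "b5"]  # what a restored dataloader would yield next

for batch_idx, batch in enumerate(batches, cur_batch_idx):
    print(batch_idx, batch)
# 3 b3
# 4 b4
# 5 b5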
@@ -909,6 +909,7 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
+    """
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
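For context: the surrounding code rebuilds the train dataloader with the saved sampler_state_dict (the commit's subject), and the """ opened here turns the OOM pre-scan into a dead string literal (closed in the next hunk), presumably so the scan does not consume batches before the resumed run starts. A minimal sketch of the sampler save/restore round trip being tested, assuming Lhotse's DynamicBucketingSampler and its state_dict()/load_state_dict() API; the cuts path and batch count are illustrative:

# Sketch of the dynamic-sampler save/restore round trip (illustrative).
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("cuts_train.jsonl.gz")  # illustrative path

sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=True)
for i, _batch in enumerate(sampler):
    if i == 10:  # pretend training stopped after 10 batches
        saved = sampler.state_dict()
        break

# Later run: a fresh sampler resumes where the old one stopped.
sampler2 = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=True)
sampler2.load_state_dict(saved)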
@@ -917,7 +918,7 @@ def run(rank, world_size, args):
             graph_compiler=graph_compiler,
             params=params,
         )
-
+    """
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
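The grad-scaler state is restored the same way as the sampler state. A self-contained sketch of the round trip using torch's GradScaler.state_dict()/load_state_dict(); the checkpoints dict here is mocked up:

# Sketch of the GradScaler save/restore shown above.
import logging

from torch.cuda.amp import GradScaler

logging.basicConfig(level=logging.INFO)

use_fp16 = True
scaler = GradScaler(enabled=use_fp16)
checkpoints = {"grad_scaler": scaler.state_dict()}  # mock of a loaded checkpoint

if checkpoints and "grad_scaler" in checkpoints:
    logging.info("Loading grad scaler state dict")
    scaler.load_state_dict(checkpoints["grad_scaler"])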