mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
test restore dynamic sampler
This commit is contained in:
parent
18a1e959f7
commit
0930748b61
@@ -251,7 +251,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
- default=8000,
|
||||
+ default=200,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
@@ -333,7 +333,7 @@ def get_params() -> AttributeDict:
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
- "log_interval": 50,
|
||||
+ "log_interval": 1,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000,
|
||||
# parameters for conformer
|
||||
@@ -682,9 +682,9 @@ def train_one_epoch(
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
- for batch_idx, batch in enumerate(train_dl):
|
||||
- if batch_idx < cur_batch_idx:
|
||||
- continue
|
||||
+ for batch_idx, batch in enumerate(train_dl, cur_batch_idx):
|
||||
+ # if batch_idx < cur_batch_idx:
|
||||
+ # continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
@@ -909,6 +909,7 @@ def run(rank, world_size, args):
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
+ """
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
@@ -917,7 +918,7 @@ def run(rank, world_size, args):
|
||||
graph_compiler=graph_compiler,
|
||||
params=params,
|
||||
)
|
||||
|
||||
+ """
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
|
Loading…
x
Reference in New Issue
Block a user