mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00
test restore dynamic sampler
This commit is contained in:
parent
18a1e959f7
commit
0930748b61
@ -251,7 +251,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-every-n",
|
"--save-every-n",
|
||||||
type=int,
|
type=int,
|
||||||
default=8000,
|
default=200,
|
||||||
help="""Save checkpoint after processing this number of batches"
|
help="""Save checkpoint after processing this number of batches"
|
||||||
periodically. We save checkpoint to exp-dir/ whenever
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||||
@ -333,7 +333,7 @@ def get_params() -> AttributeDict:
|
|||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
"best_valid_epoch": -1,
|
"best_valid_epoch": -1,
|
||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 50,
|
"log_interval": 1,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 3000,
|
"valid_interval": 3000,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
@ -682,9 +682,9 @@ def train_one_epoch(
|
|||||||
|
|
||||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl, cur_batch_idx):
|
||||||
if batch_idx < cur_batch_idx:
|
# if batch_idx < cur_batch_idx:
|
||||||
continue
|
# continue
|
||||||
cur_batch_idx = batch_idx
|
cur_batch_idx = batch_idx
|
||||||
|
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
@ -909,6 +909,7 @@ def run(rank, world_size, args):
|
|||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
@ -917,7 +918,7 @@ def run(rank, world_size, args):
|
|||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
scaler = GradScaler(enabled=params.use_fp16)
|
scaler = GradScaler(enabled=params.use_fp16)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
logging.info("Loading grad scaler state dict")
|
logging.info("Loading grad scaler state dict")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user