More fixes to the checkpoint code. (#266)

This commit is contained in:
Fangjun Kuang 2022-03-23 14:37:54 +08:00 committed by GitHub
parent 6a091da0b0
commit 3ae7265737
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -392,13 +392,16 @@ def load_checkpoint_if_available(
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
"cur_batch_idx",
]
for k in keys:
params[k] = saved_params.get(k, 0)
params[k] = saved_params[k]
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -784,6 +787,13 @@ def run(rank, world_size, args):
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
@ -798,7 +808,9 @@ def run(rank, world_size, args):
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
if checkpoints and "sampler" in checkpoints:
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None