More fixes to the checkpoint code. (#266)
commit 3ae7265737 (parent 6a091da0b0)
@@ -392,14 +392,17 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
-        "cur_batch_idx",
     ]
     for k in keys:
-        params[k] = saved_params.get(k, 0)
+        params[k] = saved_params[k]
 
-    if "cur_epoch" in saved_params:
-        params["start_epoch"] = saved_params["cur_epoch"]
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
 
     return saved_params
 
 
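With this change, cur_batch_idx is only restored when resuming from a checkpoint written in the middle of an epoch (params.start_batch > 0); a plain restart leaves it at its default. Below is a minimal, self-contained sketch of how a restored cur_batch_idx is typically consumed in a training loop: already-processed batches are skipped and the counter is kept current so the next checkpoint records the right position. This is not the actual icefall loop; run_epoch and process_batch are made-up names for illustration.

def run_epoch(params: dict, train_dl, process_batch) -> None:
    # Resume point restored by the checkpoint-loading code; 0 on a fresh run.
    cur_batch_idx = params.get("cur_batch_idx", 0)
    for batch_idx, batch in enumerate(train_dl):
        if batch_idx < cur_batch_idx:
            # This batch was consumed before the mid-epoch checkpoint was saved.
            continue
        cur_batch_idx = batch_idx
        params["cur_batch_idx"] = cur_batch_idx  # stored with the next checkpoint
        process_batch(batch)


# Toy usage: resuming with cur_batch_idx=2 skips the first two batches.
state = {"cur_batch_idx": 2}
run_epoch(state, ["b0", "b1", "b2", "b3"], process_batch=print)  # prints b2, b3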
@@ -784,6 +787,13 @@ def run(rank, world_size, args):
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
         return 1.0 <= c.duration <= 20.0
 
     num_in_total = len(train_cuts)
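The new comment points at ../local/display_manifest_statistics.py for choosing the 1.0 s and 20.0 s cutoffs from the duration distribution of the training data. As a hedged illustration (not part of this commit), the same distribution can be inspected directly with lhotse; the manifest path below is only an example and should be replaced with your own training cuts.

from lhotse import load_manifest_lazy

# Illustrative path -- substitute the cuts manifest produced by your own
# data preparation stage.
cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")

# Prints counts and a duration distribution (mean, percentiles, min/max),
# the kind of summary the display_manifest_statistics.py script is meant
# to provide; pick the short/long thresholds from it.
cuts.describe()

# Apply the same filter as remove_short_and_long_utt above.
train_cuts = cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)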
@@ -798,7 +808,9 @@ def run(rank, world_size, args):
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
 
-    if checkpoints and "sampler" in checkpoints:
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
         sampler_state_dict = checkpoints["sampler"]
     else:
         sampler_state_dict = None
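The added guard means a saved "sampler" entry (produced by the sampler's state_dict()) is ignored unless the run is actually resuming mid-epoch; on an ordinary restart the sampler begins the epoch from scratch. Here is a small self-contained restatement of that guard with made-up toy state; in the real script the selected state dict is handed to the dataloader so the sampler can continue where it left off.

from typing import Optional


def select_sampler_state(
    checkpoints: Optional[dict], start_batch: int
) -> Optional[dict]:
    # Only hand the saved sampler state back to the dataloader when resuming
    # from a checkpoint written in the middle of an epoch; otherwise the
    # sampler should start the epoch from scratch.
    if start_batch > 0 and checkpoints and "sampler" in checkpoints:
        return checkpoints["sampler"]
    return None


# Toy usage: a mid-epoch resume keeps the sampler state, a fresh epoch drops it.
print(select_sampler_state({"sampler": {"pos": 1234}}, start_batch=500))  # {'pos': 1234}
print(select_sampler_state({"sampler": {"pos": 1234}}, start_batch=0))    # None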