mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
fix the mismatch in batch_idx_train (#1757)
This commit is contained in:
parent
fbba712887
commit
2653df5bda
@ -424,8 +424,12 @@ def average_checkpoints_with_averaged_model(
|
||||
state_dict_start = torch.load(filename_start, map_location=device)
|
||||
state_dict_end = torch.load(filename_end, map_location=device)
|
||||
|
||||
average_period = state_dict_start["average_period"]
|
||||
|
||||
batch_idx_train_start = state_dict_start["batch_idx_train"]
|
||||
batch_idx_train_start = (batch_idx_train_start // average_period) * average_period
|
||||
batch_idx_train_end = state_dict_end["batch_idx_train"]
|
||||
batch_idx_train_end = (batch_idx_train_end // average_period) * average_period
|
||||
interval = batch_idx_train_end - batch_idx_train_start
|
||||
assert interval > 0, interval
|
||||
weight_end = batch_idx_train_end / interval
|
||||
|
Loading…
x
Reference in New Issue
Block a user