Reduce threshold to 1024

Daniel Povey 2022-06-01 14:42:56 +08:00
parent ca09b9798f
commit 0c73664aef


@@ -313,7 +313,7 @@ class Cain(Optimizer):
p_final_normalized = self._change_coordinates(p_final, state,
forward=True)
-        # now do block-by-block normalization
+        # now do block-by-block normalization (TODO- not done yet)
p_averaged_normalized = self._change_coordinates(p_averaged,
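
Note: judging from the surrounding calls, block-by-block normalization here likely means mapping a parameter into the learned co-ordinates and rescaling fixed-size blocks to unit RMS before mapping back. A minimal standalone sketch under that assumption (the helper name and block size are hypothetical, not taken from the Cain code):

import torch

def normalize_blocks(x: torch.Tensor, block_size: int = 4,
                     eps: float = 1e-8) -> torch.Tensor:
    # Hypothetical illustration: split the flattened tensor into
    # fixed-size blocks and scale each block to unit RMS; a trailing
    # remainder shorter than block_size is left unscaled.
    flat = x.reshape(-1)
    n = flat.numel() - flat.numel() % block_size
    blocks = flat[:n].view(-1, block_size)
    rms = (blocks.pow(2).mean(dim=1, keepdim=True) + eps).sqrt()
    out = flat.clone()
    out[:n] = (blocks / rms).reshape(-1)
    return out.view_as(x)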
@@ -341,7 +341,7 @@ class Cain(Optimizer):
# see for each dim in turn whether we want to perform any changes in co-ordinates,
# or store any stats.
size = grad.shape[dim]
-        if size <= 3 or (size % 2 == 1 and size < 128) or size >= 2048 or size == numel:
+        if size <= 3 or (size % 2 == 1 and size < 128) or size >= 1024 or size == numel:
# we don't do any such co-ordinate changes in dims with sizes
# that are too small (no point) or large (too slow), or that are
# assumed convolutional (because they are odd and not too huge).
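
The condition above reads as a predicate over dim sizes. A minimal sketch of the test as it stands after this commit (the helper name is hypothetical; the boolean expression is copied from the diff):

def wants_coordinate_change(size: int, numel: int) -> bool:
    # Negation of the skip condition: a dim qualifies for a change of
    # co-ordinates only if it is not tiny (size <= 3), not an odd
    # assumed-convolutional size below 128, not >= 1024 (too slow),
    # and not the whole tensor.
    return not (size <= 3 or (size % 2 == 1 and size < 128)
                or size >= 1024 or size == numel)

# E.g. for a (512, 3, 3, 256) conv weight, only the 512 and 256 dims
# qualify; the two size-3 kernel dims are skipped.
numel = 512 * 3 * 3 * 256
for size in (512, 3, 3, 256):
    print(size, wants_coordinate_change(size, numel))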
@@ -769,6 +769,18 @@ def average_checkpoints_with_averaged_model_rebalancing(model: torch.nn.Module,
optimizer = Cain(model.parameters())
averaged_state_dict = checkpoint.average_checkpoints_with_averaged_model(
filename_start, filename_end, device, decompose=True)
model.load_state_dict(averaged_state_dict)
c = torch.load(filename_end, map_location=device)
final_state_dict = c["model"]
# attach the final model's parameters as grads, so they will be visible
# to the optimizer object when we call rebalance()
for name, param in model.named_parameters():
param.grad = final_state_dict[name]
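
The params-as-grads trick at the end can be demonstrated in isolation: loading one state dict into the model and stashing a second set of tensors in .grad makes both visible to any Optimizer that walks its param groups. A minimal sketch assuming a toy model (the rebalance() step itself is not reproduced here):

import torch

model = torch.nn.Linear(4, 4)
# Stand-ins for the averaged and final checkpoints.
averaged_state = {k: torch.randn_like(v) for k, v in model.state_dict().items()}
final_state = {k: torch.randn_like(v) for k, v in model.state_dict().items()}

model.load_state_dict(averaged_state)  # params now hold the averaged model
for name, param in model.named_parameters():
    param.grad = final_state[name]     # grads now hold the final model
# An optimizer constructed from model.parameters() sees both tensors
# per parameter: the value in param.data and the companion in param.grad.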