Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 08:34:19 +00:00)
Reduce threshold to 1024
commit 0c73664aef
parent ca09b9798f
@@ -313,7 +313,7 @@ class Cain(Optimizer):
             p_final_normalized = self._change_coordinates(p_final, state,
                                                           forward=True)

-            # now do block-by-block normalization
+            # now do block-by-block normalization (TODO- not done yet)


             p_averaged_normalized = self._change_coordinates(p_averaged,
@@ -341,7 +341,7 @@ class Cain(Optimizer):
             # see for each dim in turn whether we want to perform any changes in co-ordinates,
             # or store any stats.
             size = grad.shape[dim]
-            if size <= 3 or (size % 2 == 1 and size < 128) or size >= 2048 or size == numel:
+            if size <= 3 or (size % 2 == 1 and size < 128) or size >= 1024 or size == numel:
                 # we don't do any such co-ordinate changes in dims with sizes
                 # that are too small (no point) or large (too slow), or that are
                 # assumed convolutional (because they are odd and not too huge).
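Note: for reference, below is a minimal standalone sketch of the skip test as it reads after this change. The function name and the example asserts are illustrative only and not part of the diff; numel stands for the total element count of the parameter tensor, as in the surrounding code.

def skips_coordinate_change(size: int, numel: int) -> bool:
    # A dim is left in its original co-ordinates when it is tiny (no point),
    # odd and fairly small (assumed convolutional), at least 1024 (previously
    # 2048, i.e. too slow to transform), or spans the whole tensor.
    return (size <= 3
            or (size % 2 == 1 and size < 128)
            or size >= 1024
            or size == numel)

# A dim of size 1536 is newly skipped under the 1024 threshold; it would still
# have been transformed under the old 2048 threshold.
assert skips_coordinate_change(3, 3 * 512)           # too small
assert skips_coordinate_change(127, 127 * 512)       # odd and < 128: assumed convolutional
assert not skips_coordinate_change(512, 512 * 2048)  # still eligible for a co-ordinate change
assert skips_coordinate_change(1536, 1536 * 512)     # >= 1024: newly skipped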
@@ -769,6 +769,18 @@ def average_checkpoints_with_averaged_model_rebalancing(model: torch.nn.Module,
 
     optimizer = Cain(model.parameters())
 
+    averaged_state_dict = checkpoint.average_checkpoints_with_averaged_model(
+        filename_start, filename_end, device, decompose=True)
+    model.load_state_dict(averaged_state_dict)
+
+    c = torch.load(filename_end, map_location=device)
+    final_state_dict = c["model"]
+
+    # attach the final model's parameters as grads, so they will be visible
+    # to the optimizer object when we call rebalance()
+    for name, param in model.named_parameters():
+        param.grad = final_state_dict[name]
+
 
 
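Note: the key trick in the added lines is attaching the final checkpoint's parameters as .grad on the averaged model, so the optimizer can see the averaged and final tensors side by side when rebalance() is called. The sketch below is illustrative only: ToyRebalancer and its norm-matching rule are stand-ins, not the actual Cain.rebalance() logic; it just demonstrates that an optimizer method can read the attached tensors through p.grad.

import torch


class ToyRebalancer(torch.optim.Optimizer):
    """Stand-in optimizer showing the attach-as-grad pattern; not Cain."""

    def __init__(self, params):
        super().__init__(params, defaults={})

    @torch.no_grad()
    def rebalance(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # p is the averaged parameter; p.grad holds the corresponding
                # parameter from the final checkpoint, attached by the caller.
                cur_norm = p.norm()
                if cur_norm > 0:
                    # Toy rule: rescale the averaged parameter to the final
                    # parameter's norm (illustrative, not the real rebalancing).
                    p.mul_(p.grad.norm() / cur_norm)


model = torch.nn.Linear(4, 4)
# Pretend this came from torch.load(filename_end)["model"].
final_state_dict = {k: torch.randn_like(v) for k, v in model.state_dict().items()}

opt = ToyRebalancer(model.parameters())
for name, param in model.named_parameters():
    param.grad = final_state_dict[name]  # same attachment as in the diff
opt.rebalance()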