Revert the exclusion of dim=500

This commit is contained in:
Daniel Povey 2022-05-28 17:49:16 +08:00
parent 0b645662f9
commit 295595d334

View File

@@ -772,13 +772,11 @@ class Cain(Optimizer):
# see for each dim in turn whether we want to perform any changes in co-ordinates, # see for each dim in turn whether we want to perform any changes in co-ordinates,
# or store any stats. # or store any stats.
size = grad.shape[dim] size = grad.shape[dim]
if size <= 3 or size % 2 == 1 or size == 500 or size >= 2048 or size == numel: if size <= 3 or (size % 2 == 1 and size < 128) or size >= 2048 or size == numel:
# 500: exclude embedding dim, will later find a better way to do this.
# we don't do any such co-ordinate changes in dims with sizes # we don't do any such co-ordinate changes in dims with sizes
# that are too small (no point) or large (too slow), or that are # that are too small (no point) or large (too slow), or that are
# assumed convolutional (because they are odd). We can revisit # assumed convolutional (because they are odd and not too huge).
# this later. # We can revisit this later.
continue continue
grad = self._change_coordinates_for_dim(grad, state, dim, forward) grad = self._change_coordinates_for_dim(grad, state, dim, forward)
return grad return grad