diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 1d87f3f73..6d904e0bf 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -128,6 +128,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     size_update_period: The periodicity, in steps, with which we update the size (scale)
           of the parameter tensor.  This is provided to save a little time.
     lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
+          ** This is important for the speed/optimization tradeoff. **
+    param_cov_period: The periodicity, in steps, with which we update the parameter covariance
+          stats.
+    max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
     """
     def __init__(
         self,
@@ -147,7 +151,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr_update_period=200,
         grad_cov_period=3,
         param_cov_period=100,
-        max_fullcov_size=2048,
+        max_block_size=1024,
     ):

@@ -168,7 +172,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr_update_period=lr_update_period,
             grad_cov_period=grad_cov_period,
             param_cov_period=param_cov_period,
-            max_fullcov_size=max_fullcov_size,
+            max_block_size=max_block_size,
         )
         super(PrAdam, self).__init__(params, defaults)
@@ -226,7 +230,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         eps = group["eps"]
         size_update_period = group["size_update_period"]
-        max_fullcov_size = group["max_fullcov_size"]
+        max_block_size = group["max_block_size"]

         state["step"] = 0
@@ -288,34 +292,51 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             if size == 1 or (size == numel and ignore_rank1_dims):
                 continue

-            if size <= max_fullcov_size:
-                # Q_{dim} is be the learning-rate matrix for this
-                # dimension, a matrix of indexed [diagonalized_coordinate, canonical_coordinate].
-                # this is used twice in the update step, once transposed and once without transpose.
-                state[f"Q_{dim}"] = torch.eye(size, **kwargs).unsqueeze(0).expand(
-                    batch_size, size, size).contiguous()
+            # Most of the time, (num_blocks, block_size) will equal (1, size)
+            num_blocks, block_size = self._get_block_size(size, max_block_size)

-                # param_cov_{dim} is the averaged-over-time gradient of parameters on this dimension, treating
-                # all other dims as a batch axis.
-                state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
+            # Q_{dim} is the learning-rate matrix for this dimension, a matrix
+            # indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
+            # This is used twice in the update step, once transposed and once without transpose.
+            Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
+                batch_size, num_blocks, block_size, block_size).contiguous()
+            state[f"Q_{dim}"] = Q
+            # param_cov_{dim} is the averaged-over-time covariance of the parameters on this dimension,
+            # treating all other dims as a batch axis.
+            state[f"param_cov_{dim}"] = torch.zeros_like(Q)

-                # grad_cov_{dim} is the covariance of gradients on this axis (without
-                # any co-ordinate changes), treating all other axes as as a batch axis.
-                # This is needed when we re-diagonalize, and to compute the
-                # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2
+            # grad_cov_{dim} is the covariance of gradients on this axis (without
+            # any co-ordinate changes), treating all other axes as a batch axis.
+            # This is needed when we re-diagonalize, and to compute the
+            # grad_rms_{dim}.  We store it as a decaying average, decaying with beta2
-                # only for purposes of computing the scalar factor on the
-                # learning rate; and, as a result of this, also contributes something
-                # to the gradient w.r.t. f"Q_{dim}", which is one reason
-                # why we allocate a variable to keep track of its moving average
-                # instead of just using a temporary and smoothing the scalar factor.
-                state[f"grad_cov_{dim}"] = torch.zeros(batch_size, size, size, **kwargs)
-            else:
-                # diagonal-only Q and param_cov, no grad_cov needed because it is only
-                # needed for multiplying Q_{dim} by an orthogonal matrix, which is not
-                # applicable if Q_{dim} is diagonal.
-                state[f"Q_{dim}"] = torch.ones(batch_size, size, **kwargs)
-                state[f"param_cov_{dim}"] = torch.zeros(batch_size, size, **kwargs)
+            # only for purposes of computing the scalar factor on the
+            # learning rate; and, as a result of this, also contributes something
+            # to the gradient w.r.t. f"Q_{dim}", which is one reason
+            # why we allocate a variable to keep track of its moving average
+            # instead of just using a temporary and smoothing the scalar factor.
+            state[f"grad_cov_{dim}"] = torch.zeros_like(Q)
+
+
+    def _get_block_size(self, size: int, max_block_size: int) -> Tuple[int,int]:
+        """
+        Returns information about the block size for a block-diagonal structure
+        of covariances and co-ordinate transformations.
+
+          size: The size of the parameter tensor on one of its dimensions, e.g.
+               a channel dim or kernel size
+          max_block_size: The maximum block size allowed (of the blocks in a
+               block-diagonal structure)
+        Returns:
+           (num_blocks, block_size)
+        where num_blocks * block_size == size and block_size <= max_block_size
+        """
+        for num_blocks in range(1, size):
+            block_size = size // num_blocks
+            if block_size * num_blocks == size and block_size <= max_block_size:
+                return (num_blocks, block_size)
+        assert False, (size, max_block_size)  # code error or e.g. negative or
+                                              # zero inputs

     def _step_one_batch(self,
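For intuition about the block-diagonal structure introduced here: the standalone sketch below (illustration only, not applied by the patch) restates the _get_block_size search as a free function and shows what it returns for a few made-up channel sizes. Under the old max_fullcov_size scheme, dims above the threshold fell back to a purely diagonal Q; with blocks, every dim keeps a (block-diagonal) full-covariance treatment.

def get_block_size(size: int, max_block_size: int):
    # same search as PrAdam._get_block_size above: pick the smallest number of
    # equal-sized blocks whose block_size does not exceed max_block_size.
    for num_blocks in range(1, size):
        block_size = size // num_blocks
        if block_size * num_blocks == size and block_size <= max_block_size:
            return (num_blocks, block_size)
    assert False, (size, max_block_size)

assert get_block_size(384, 1024) == (1, 384)    # dims <= max_block_size keep one full block
assert get_block_size(2048, 1024) == (2, 1024)  # larger dims are split into equal blocks
assert get_block_size(1536, 1024) == (2, 768)   # the block size need not equal max_block_size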
@@ -462,30 +483,23 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             except KeyError:
                 continue  # e.g. size == 1 or size == numel

-            if param_cov.ndim == 3:
-                # param_cov shape: (batch_size, size, size)
+            # param_cov shape: (batch_size, num_blocks, block_size, block_size)
+            (batch_size, num_blocks, block_size, block_size) = param_cov.shape

-                # M is the parameter matrix, of shape (-1, size) where the -1 covers
-                # all other dimensions of the tensor.
-                M_full = p.transpose(dim, -1)
-                M = M_full.reshape(batch_size, -1, size)
+            # M: (batch_size, num_blocks, x, block_size)
+            M = p.transpose(dim, -1).reshape(batch_size, -1,
+                                             num_blocks, block_size).transpose(1, 2)

-                # will be batched matrix multiply. shape of this_param_cov: (batch_size, size, size)
-                this_param_cov = torch.matmul(M.transpose(1,2), M)
-                # normalize scale of this param_cov, in case parameter scale
-                # changes significantly during training, which would cause some
-                # parts of the training timeline to be more highly weighted.
-                # shape of this_param_cov
-                this_param_cov /= _mean(_diag(this_param_cov),  # _diag(this_param_cov) has shape (batch_size, size)
-                                        exclude_dims=[0],
-                                        keepdim=True).unsqueeze(-1) + eps  # shape: (batch_size, 1, 1)
-            else:
-                # this_param_cov dim: (batch_size, size)
-                this_param_cov = _mean(p**2,
-                                       exclude_dims=[0,dim])
-                this_param_cov /= _mean(this_param_cov,
-                                        exclude_dims=[0],
-                                        keepdim=True) + eps  # shape: (batch_size, 1)
+            # will be a batched matrix multiply.  shape of this_param_cov:
+            # (batch_size, num_blocks, block_size, block_size)
+            this_param_cov = torch.matmul(M.transpose(2, 3), M)
+
+            # normalize scale of this param_cov, in case parameter scale
+            # changes significantly during training, which would cause some
+            # parts of the training timeline to be more highly weighted.
+            # shape of this_param_cov is unchanged; the expression after /=
+            # has shape (batch_size, num_blocks, 1, 1)
+            this_param_cov /= _diag(this_param_cov).mean(dim=-1, keepdim=True).unsqueeze(-1) + eps

             param_cov.mul_(1-this_weight).add_(this_param_cov, alpha=this_weight)
@@ -543,23 +557,30 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 assert size == 1 or size == numel, size
                 continue  # e.g. size == 1 or size == numel:

-            if param_cov.ndim == 2:
-                # size > max_fullcov_size, we do not accumulate full covariance
-                # matrix for this dim and will use a diagonal Q.
-                # param_cov.shape is (batch_size, size).
-                Q[:] = 1.0
-                S = param_cov  # (batch_size, size)
-            else:
-                # if batch==True, param_cov.shape == (batch_size, size, size), and
-                # U, S and V have an extra leading dim.
-                U, S, V = _svd(param_cov)
-                Q[:] = U.transpose(1, 2)
+            # param_cov has the same shape as Q
+            (batch_size, num_blocks, block_size, block_size) = Q.shape
+            # U,S,V have extra leading dimensions, i.e. of shape
+            # {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
+            # S.shape == (batch_size, num_blocks, block_size).
+            U, S, V = _svd(param_cov)
+            Q[:] = U.transpose(2, 3)

-            M = cur_p.transpose(dim, -1)  # (batch_size, x, y, z, size)
-            while U.ndim < M.ndim:
-                U = U.unsqueeze(1)   # (batch_size, 1, 1, size, size)
-            M = torch.matmul(M, U)  # (batch_size, x, y, z, size)
-            cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
+            M = cur_p.transpose(dim, -1)
+            # if cur_p were of shape (batch_size, x, size, y, z),
+            # after the next line M will be of shape
+            # (batch_size, x, y, z, num_blocks, block_size)
+            M = M.reshape(batch_size, *M.shape[1:-1],
+                          num_blocks, block_size)
+            M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
+
+            while U.ndim < M.ndim:
+                U = U.unsqueeze(2)
+            # Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
+
+            M = torch.matmul(M, U)  # (batch_size, num_blocks, x, y, z, block_size)
+            M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
+            M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
+            cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)

             # cur_param_var is a diagonal parameter variance over dimension `dim`,
             # of the current "slightly-whitened" parameter; it
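The reshape / _move_dim dance above (repeated below in _update_grad_cov and _project) is just a way of multiplying along one dimension by a block-diagonal matrix without materializing it. A small self-contained check of that equivalence, with made-up shapes; move_dim restates the _move_dim helper added later in this patch, and torch.block_diag builds the explicit reference matrix:

import torch

def move_dim(x, orig_dim, new_dim):
    # restatement of _move_dim from this patch, so the check runs on its own
    if orig_dim < 0:
        orig_dim += x.ndim
    if new_dim < 0:
        new_dim += x.ndim
    dims = list(range(x.ndim))
    dims[orig_dim] = -1
    dims.remove(-1)
    dims.insert(new_dim, orig_dim)
    return x.permute(dims)

batch_size, x_dim, size = 2, 5, 6
num_blocks, block_size = 3, 2
M = torch.randn(batch_size, x_dim, size)
U = torch.randn(batch_size, num_blocks, block_size, block_size)  # one matrix per (batch element, block)

# block-wise route, as in the hunk above: split `size` into blocks, move the block
# axis next to the batch axis, do a batched matmul, then undo the reshaping.
Mb = M.reshape(batch_size, x_dim, num_blocks, block_size)
assert torch.equal(move_dim(Mb, -2, 1), torch.movedim(Mb, -2, 1))  # sanity-check the helper
Mb = move_dim(Mb, -2, 1)            # (batch_size, num_blocks, x_dim, block_size)
Mb = torch.matmul(Mb, U)            # (batch_size, num_blocks, x_dim, block_size)
Mb = move_dim(Mb, 1, -2).reshape(batch_size, x_dim, size)

# reference route: explicitly build the block-diagonal matrix per batch element.
ref = torch.stack([M[b] @ torch.block_diag(*U[b]) for b in range(batch_size)])
assert torch.allclose(Mb, ref, atol=1e-5)

In the patch itself U is additionally unsqueezed because the parameter can have several trailing dims; here M has only one, so the shapes already line up.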
@@ -578,8 +599,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # (where the stats permit).
             scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
             logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}")

-            # scale: (batch_size, 1, size, 1, 1) if dim==2
+            # scale shape: (batch_size, 1, size, 1, 1) if dim==2
             cur_p *= scale

             # OK, at this point we have a matrix cur_p that is (somewhat)
@@ -603,7 +624,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # cur_scales for the other dims.
         cur_scales = [None] * ndim

-        debug = True #(random.random() < 0.001)
         for i in range(4):  # for 4 iterations (this is quite arbitrary)
             for dim in range(1, ndim):
@@ -637,12 +657,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             if cur_scales[dim] is not None:
                 size = p.shape[dim]
                 Q = state[f"Q_{dim}"]
+                (batch_size, num_blocks, block_size, block_size) = Q.shape

-                scale = cur_scales[dim].reshape(batch_size, size)
-                if Q.ndim == 3:
-                    # Q is indexed [batch_index, diagonalized_coordinate, canonical_coordinate],
-                    # want to multiply on the diagonalized co-ordinate.
-                    scale = scale.unsqueeze(-1)
+                scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
+                # Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate];
+                # we want to multiply on the diagonalized co-ordinate.
                 # else: Q is indexed [batch_index, canonical_coordinate].
                 state[f"Q_{dim}"] *= scale
@@ -658,15 +677,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         by left-multiplying the projections by an orthogonal matrix.
         """
         batch_size = p.shape[0]
-        max_fullcov_size = group["max_fullcov_size"]
         numel = p.numel() // batch_size
         for dim in range(1, p.ndim):  # dim 0 is batch dim
             size = p.shape[dim]
             try:
+                # Q and grad_cov shape: (batch_size, num_blocks, block_size, block_size)
                 Q = state[f"Q_{dim}"]
-                grad_cov = state[f"grad_cov_{dim}"]  # grad_cov shape: (batch_size, size_size)
+                grad_cov = state[f"grad_cov_{dim}"]
             except KeyError:
-                assert size == 1 or size == numel or size > max_fullcov_size
+                assert size == 1 or size == numel
                 continue

             # Suppose the actual parameter matrix p is M, of shape (-1, size), where
@@ -692,15 +711,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 # a ratio that, if the eigenvalues are more dispersed, it will be larger.
                 return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()

-            # N_grad_cov shape: (batch_size, size, size)
-            N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(1, 2)))
-            N_grad_cov = N_grad_cov + N_grad_cov.transpose(1, 2)  # ensure symmetric
+
+            # N_grad_cov shape: (batch_size, num_blocks, block_size, block_size)
+            N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
+            N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)  # ensure symmetric
             U, S, V = _svd(N_grad_cov)

             if random.random() < 0.001:
                 logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
                              f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
-            # N_grad_cov is SPD, so
             # N_grad_cov = U S U^T.
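The re-diagonalization above relies on a standard fact: if N = Q C Q^T is symmetric positive definite and N = U S U^T is its SVD, then replacing Q by U^T Q makes the gradient covariance diagonal in the new co-ordinates. A quick numerical check of that fact, standalone and with made-up shapes, using torch.svd directly instead of the file's _svd helper:

import torch

batch_size, num_blocks, block_size = 2, 3, 4
A = torch.randn(batch_size, num_blocks, block_size, block_size)
grad_cov = torch.matmul(A, A.transpose(2, 3)) + 0.1 * torch.eye(block_size)  # SPD per (batch, block)
Q = torch.randn(batch_size, num_blocks, block_size, block_size)              # stand-in for the current Q

N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)   # symmetrize, as in the hunk above
U, S, V = torch.svd(N_grad_cov)                        # N_grad_cov = U S U^T since it is SPD
Q_new = torch.matmul(U.transpose(2, 3), Q)             # the update applied to Q in the next hunk

# in the new co-ordinates the gradient covariance is (numerically) diagonal
N_new = torch.matmul(Q_new, torch.matmul(grad_cov, Q_new.transpose(2, 3)))
off_diag = N_new - torch.diag_embed(torch.diagonal(N_new, dim1=-2, dim2=-1))
assert off_diag.abs().max() < 1e-4 * N_new.abs().max()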
@@ -725,7 +744,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             #
             # This is the only thing we have to do, as N is implicit
             # and not materialized at this point.
-            Q[:] = torch.matmul(U.transpose(1, 2), Q)
+            Q[:] = torch.matmul(U.transpose(2, 3), Q)

     def _update_grad_cov(self,
                          group: dict,
@@ -742,10 +761,24 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size = grad.shape[dim]
             name = f"grad_cov_{dim}"
             if name not in state:
-                continue  # e.g. size==1, size == numel, size > max_fullcov_size
-            grad_cov = state[name]  # shaped: (batch_size, size, size)
-            this_g = grad.transpose(-1, dim).reshape(batch_size, -1, size)
-            grad_cov.mul_(beta2).add_(torch.matmul(this_g.transpose(1,2), this_g))
+                continue  # e.g. size==1 or size == numel
+            grad_cov = state[name]  # shaped: (batch_size, num_blocks, block_size, block_size)
+            (batch_size, num_blocks, block_size, block_size) = grad_cov.shape
+
+            g = grad.transpose(-1, dim)
+            # if grad were of shape (batch_size, x, size, y, z),
+            # and g of shape (batch_size, x, y, z, size),
+            # after the next line g will be of shape
+            # (batch_size, x, y, z, num_blocks, block_size)
+            g = g.reshape(batch_size, *g.shape[1:-1],
+                          num_blocks, block_size)
+            g = _move_dim(g, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
+            g = g.reshape(batch_size, num_blocks, -1, block_size)
+
+            # this_grad_cov: (batch_size, num_blocks, block_size, block_size)
+            this_grad_cov = torch.matmul(g.transpose(-2, -1), g)
+
+            grad_cov.mul_(beta2).add_(this_grad_cov, alpha=(1-beta2))

     def _step(self,
               group: dict,
@@ -771,7 +804,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr = group["lr"]
         beta1, beta2 = group["betas"]
         eps = group["eps"]
-        grad_cov_period = group["grad_cov_period"]
         step = state["step"]

         grad = self._project(grad, state, forward=True)
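One way to read the new _update_grad_cov: per block, g^T g is exactly the corresponding diagonal block of the full (size x size) gradient covariance, so the block-diagonal state keeps the same numbers the old full matrix held on its diagonal blocks, and the increment is now scaled by (1-beta2), making grad_cov a conventional exponential moving average. A small standalone check with made-up shapes; permute(0, 2, 1, 3) plays the role of _move_dim(g, -2, 1) for this 4-D case:

import torch

batch_size, x_dim, size = 2, 7, 6
num_blocks, block_size = 2, 3
grad = torch.randn(batch_size, x_dim, size)   # gradient with the dim of interest last

# block-wise covariance, as accumulated above
g = grad.reshape(batch_size, x_dim, num_blocks, block_size).permute(0, 2, 1, 3)
block_cov = torch.matmul(g.transpose(-2, -1), g)      # (batch_size, num_blocks, block_size, block_size)

# the same numbers are the diagonal blocks of the full covariance
full_cov = torch.matmul(grad.transpose(1, 2), grad)   # (batch_size, size, size)
for k in range(num_blocks):
    sl = slice(k * block_size, (k + 1) * block_size)
    assert torch.allclose(block_cov[:, k], full_cov[:, sl, sl], atol=1e-5)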
@@ -843,55 +875,53 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         p.add_(delta)

     def _project(self,
-                 x: Tensor,
+                 M: Tensor,
                  state: dict,
                  forward: bool) -> Tensor:
         """
-        Multiply a tensor x by proj_{dim} on each of its dimensions for which we have
-        a projection Q.
-        If forward == True, converts x from canonical to preconditioned/diagonalized
-        co-ordinates; if forward == False, does the reverse.
+        Multiply a tensor M by a projection Q on each of its dimensions for which we
+        have such a projection.
         Args:
-           x: The tensor to project, e.g. a gradient or a parameter change.
+           M: The tensor to project, e.g. a gradient or a parameter change.
           state: dict to look up the projections
           forward: if True, go in the forward direction (from canonical to
                diagnonalized co-ordinates); if False, the reverse.  They differ by
                a transpose, not an inverse, and do not make a round trip.
         """
-        numel = x.numel()
-        for dim in range(1, x.ndim):  # dim 0 is batch dim
+        numel = M.numel()
+
+        for dim in range(1, M.ndim):  # dim 0 is batch dim (batch of parameter
+                                      # tensors of same shape)
            try:
-               Q = state[f"Q_{dim}"]  # shape: (batch_size, size, size)
+               Q = state[f"Q_{dim}"]  # shape: (batch_size, num_blocks, block_size, block_size)
            except KeyError:
                continue  # no projection for this dim

-            fullcov = (Q.ndim == 3)  # (batch_index, diagonalized_index, canonical_index), else
-                                     # (batch_index, canonical_index).
+            (batch_size, num_blocks, block_size, block_size) = Q.shape
+            size = M.shape[dim]  # == num_blocks * block_size

-            if fullcov:
-                if forward:
-                    # Q is indexed [batch_index, diagonalized_index, canonical_index]; in the forward
-                    # direction we want to change canonical to diagonalized index, so have
-                    # to transpose.
-                    Q = Q.transpose(1, 2)
-                    # TODO: could possibly somehow force the output memory format to be
-                    # unchanged.
-                while Q.ndim < x.ndim:
-                    Q = Q.unsqueeze(1)
-                # now x might have shape (batch_size, 3, 4, 5, 6)
-                # and Q might have shape (batch_size, 1, 1, 6, 6)
-                # ... and the output will still have shape (batch_size, 3, 4, 5, 6)
-                x = x.transpose(-1, dim)
-                x = torch.matmul(x, Q)
-                x = x.transpose(-1, dim)
-            else:
-                shape = [1] * x.ndim
-                shape[0] = x.shape[0]  # batch_size
-                shape[dim] = x.shape[dim]  # size
-                Q = Q.reshape(shape)  # e.g. (batch_size, 1, size, 1, 1)
-                x = x * Q
-        return x
+            if forward:
+                # Q is indexed
+                # [batch_idx, block_idx, diagonalized_idx, canonical_idx]; in
+                # the forward direction we want to change canonical to
+                # diagonalized index, so we transpose.
+                Q = Q.transpose(2, 3)
+
+            # assume M currently has shape (batch_size, x, size, y, z), and dim==2.
+            M = M.transpose(-1, dim)  # (batch_size, x, y, z, size)
+            M = M.reshape(batch_size, *M.shape[1:-1],
+                          num_blocks, block_size)  # (batch_size, x, y, z, num_blocks, block_size)
+            M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
+
+            while Q.ndim < M.ndim:
+                Q = Q.unsqueeze(2)
+            # now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
+            M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)
+            M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
+            M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
+            M = M.transpose(-1, dim)  # (batch_size, x, size, y, z)
+        return M

     def _smooth_param_rms(self,
                           group: dict,
@@ -941,6 +971,31 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     return ans


+def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
+    """
+    Moves the position of one dimension in a Tensor, while keeping
+    the remaining dims in the same order.  E.g. if x has
+    shape [10, 1, 2, 3], _move_dim(x, 3, 1) will have shape
+    [10, 3, 1, 2].  Returns the result, which is a permuted
+    view of the original Tensor.
+    Args:
+          x: Tensor to reshape
+          orig_dim: Original dimension to move, with -x.ndim <= orig_dim < x.ndim
+          new_dim: New position of the original dimension, with
+               -x.ndim <= new_dim < x.ndim
+    Returns: a permuted view of x
+    """
+    if orig_dim < 0:
+        orig_dim += x.ndim
+    if new_dim < 0:
+        new_dim += x.ndim
+    dims = list(range(x.ndim))
+    dims[orig_dim] = -1
+    dims.remove(-1)
+    dims.insert(new_dim, orig_dim)
+    return x.permute(dims)
+
+
 def _diag(x: Tensor):
     """
     like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
@@ -951,6 +1006,11 @@ def _diag(x: Tensor):
         assert M == M2
         stride = x.stride()
         return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
+    elif x.ndim == 4:
+        (B, C, M, M2) = x.shape
+        assert M == M2
+        stride = x.stride()
+        return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
     else:
         return x.diag()

@@ -1676,7 +1736,7 @@ def _test_eve_cain():
         if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
         elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
-        elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_fullcov_size=150)
+        elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

         start = timeit.default_timer()
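Finally, a quick standalone check, with made-up shapes, that the new 4-D branch of _diag() picks out the same elements as torch.diagonal over the last two dims:

import torch

x = torch.randn(2, 3, 5, 5)
(B, C, M, M2) = x.shape
stride = x.stride()
diag_strided = x.as_strided(size=(B, C, M),
                            stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
assert torch.equal(diag_strided, torch.diagonal(x, dim1=-2, dim2=-1))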