diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 257936b59..bcaa90509 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple
-
+import logging
 import torch
 from encoder_interface import EncoderInterface
 from scaling import (
@@ -127,6 +127,11 @@ class Conformer(EncoderInterface):
 
         return x, lengths
 
+    def orthogonalize(self) -> None:
+        # Currently we only orthogonalize the self-attention modules, but we
+        # could in principle do more layers.
+        for m in self.encoder.layers:
+            m.self_attn.orthogonalize()
 
 class ConformerEncoderLayer(nn.Module):
     """
@@ -321,6 +326,7 @@ class ConformerEncoder(nn.Module):
 
         return output
 
+
 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.
 
@@ -829,6 +835,79 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None
 
+    @torch.no_grad()
+    def orthogonalize(self) -> None:
+        """
+        Rotate some parameters to try to segregate large vs. small rows/columns
+        of the parameter matrices.  The intention is to improve the convergence
+        of Adam-type update formulas (with SGD, the update would be invariant
+        to such rotations).
+        """
+        num_heads = self.num_heads
+        attn_dim = self.in_proj.weight.shape[1]
+        query_proj = self.in_proj.weight[0:attn_dim, :].chunk(num_heads, dim=0)
+        key_proj = self.in_proj.weight[attn_dim:2 * attn_dim, :].chunk(num_heads, dim=0)
+        linear_pos_proj = self.linear_pos.weight.chunk(num_heads, dim=0)
+        value_proj = self.in_proj.weight[2 * attn_dim:, :].chunk(num_heads, dim=0)
+        out_proj = self.out_proj.weight.t().chunk(num_heads, dim=0)
+
+        def get_normalized_covar(x: Tensor) -> Tensor:
+            """
+            Returns a covariance matrix normalized to have trace == dim.
+            Args:
+                x: a matrix of shape (i, j)
+            Returns: a covariance matrix of shape (i, i), equal to matmul(x, x.t())
+            """
+            covar = torch.matmul(x, x.t())
+            return covar * (x.shape[0] / covar.trace())
+
+
+        def get_proj(*args) -> Tensor:
+            """
+            Returns an orthogonal projection that diagonalizes the summed
+            covariance of the given matrices.  If the args mat1, mat2, ...
+            are of shape (dim0, dim1), it is their summed (dim0, dim0)
+            covariance that we diagonalize.
+
+            Args: mat1, mat2, etc., which should all be matrices of the same
+            shape (dim0, dim1)
+
+            Returns: a projection indexed (new_dim0, old_dim0), i.e. of shape
+            dim0 by dim0, where the 1st index is the newly created index.
+ """ + cov = get_normalized_covar(args[0]) + for a in args[1:]: + cov += get_normalized_covar(a) + old_diag_stddev = cov.diag().var().sqrt().item() + l, U = cov.symeig(eigenvectors=True) + new_diag_stddev = l.var().sqrt().item() + logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} " + f"to {new_diag_stddev:.3e}") + return U.t() # U.t() is indexed (new_dim, old_dim) + + # first do the query/key space + logging.info("Diagonalizing query/key space") + for i in range(num_heads): + q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i] + qk_proj = get_proj(q, k, l) + q[:] = torch.matmul(qk_proj, q) + k[:] = torch.matmul(qk_proj, k) + l[:] = torch.matmul(qk_proj, l) + pos_u[:] = torch.mv(qk_proj, pos_u) + pos_v[:] = torch.mv(qk_proj, pos_v) + + # Now do the value space + logging.info("Diagonalizing value space") + for i in range(num_heads): + v, o = value_proj[i], out_proj[i] + v_proj = get_proj(v, o) + v[:] = torch.matmul(v_proj, v) + o[:] = torch.matmul(v_proj, o) + + + + + class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model. @@ -1026,13 +1105,22 @@ class Conv2dSubsampling(nn.Module): if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) feature_dim = 50 c = Conformer(num_features=feature_dim, d_model=128, nhead=4) batch_size = 5 seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, - ) + # Make sure the forward pass runs, and that orthogonalize() does not + # change its output. + feats = torch.randn(batch_size, seq_len, feature_dim) + x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64) + + c.eval() # use test mode and warmup=1.0 so it is deterministic. + y1 = c(feats, x_lens, warmup=1.0)[0] + c.orthogonalize() + y2 = c(feats, x_lens, warmup=1.0)[0] + + diff_norm = (y1-y2).norm() + y_norm = y1.norm() + print(f"diff_norm={diff_norm}, y_norm={y_norm}") + assert diff_norm < 0.001 * y_norm diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index a79f29c30..2eb9de4f3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -701,6 +701,11 @@ def train_one_epoch( continue cur_batch_idx = batch_idx + if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0: + mmodel = model.module if hasattr(model, 'module') else model + mmodel.encoder.orthogonalize() + optimizer.reset() + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"])