Orthogonalize every 2k iters

This commit is contained in:
Daniel Povey 2022-05-15 21:50:40 +08:00
parent bb32556f9e
commit cee5396058
2 changed files with 100 additions and 7 deletions

View File

@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import Optional, Tuple
import logging
import torch
from encoder_interface import EncoderInterface
from scaling import (
@ -127,6 +127,11 @@ class Conformer(EncoderInterface):
return x, lengths
def orthogonalize(self) -> None:
# currently only orthogonalize the self-attention modules, but could in principle
# do more layers.
for m in self.encoder.layers:
m.self_attn.orthogonalize()
class ConformerEncoderLayer(nn.Module):
"""
@ -321,6 +326,7 @@ class ConformerEncoder(nn.Module):
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
@ -829,6 +835,79 @@ class RelPositionMultiheadAttention(nn.Module):
else:
return attn_output, None
@torch.no_grad()
def orthogonalize(self) -> None:
"""
Rotate some parameters to try to segregate large vs. small rows/columns of
the parameter matrices. The intention is to improve the convergence of
Adam-type update formulas (with SGD, the update would be invariant to
such rotations).
"""
num_heads = self.num_heads
attn_dim = self.in_proj.weight.shape[1]
query_proj = self.in_proj.weight[0:attn_dim,:].chunk(num_heads, dim=0)
key_proj = self.in_proj.weight[attn_dim:2*attn_dim,:].chunk(num_heads, dim=0)
linear_pos_proj = self.linear_pos.weight.chunk(num_heads, dim=0)
value_proj = self.in_proj.weight[2*attn_dim:,:].chunk(num_heads, dim=0)
out_proj = self.out_proj.weight.t().chunk(num_heads, dim=0)
def get_normalized_covar(x: Tensor) -> Tensor:
"""
Returns a covariance matrix normalized to have trace==dim.
Args:
x: a matrix of shape (i, j)
Returns: a covariance matrix of shape (i, i), equal to matmul(x, x.t())
"""
covar = torch.matmul(x, x.t())
return covar * (x.shape[0] / covar.trace())
def get_proj(*args) -> Tensor:
"""
Returns a covariance-diagonalizing projection that diagonalizes
the summed covariance from these two projections. If mat1,mat2
are of shape (dim0, dim1), it's the (dim0, dim0) covariance,
that we diagonalize.
Args: mat1, mat2, etc., which should all be matrices of the same
shape (dim0, dim1)
Returns: a projection indexed (new_dim0, old_dim0), i.e. of
shape dim0 by dim0 but 1st index is the newly created indexes.
"""
cov = get_normalized_covar(args[0])
for a in args[1:]:
cov += get_normalized_covar(a)
old_diag_stddev = cov.diag().var().sqrt().item()
l, U = cov.symeig(eigenvectors=True)
new_diag_stddev = l.var().sqrt().item()
logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
f"to {new_diag_stddev:.3e}")
return U.t() # U.t() is indexed (new_dim, old_dim)
# first do the query/key space
logging.info("Diagonalizing query/key space")
for i in range(num_heads):
q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i]
qk_proj = get_proj(q, k, l)
q[:] = torch.matmul(qk_proj, q)
k[:] = torch.matmul(qk_proj, k)
l[:] = torch.matmul(qk_proj, l)
pos_u[:] = torch.mv(qk_proj, pos_u)
pos_v[:] = torch.mv(qk_proj, pos_v)
# Now do the value space
logging.info("Diagonalizing value space")
for i in range(num_heads):
v, o = value_proj[i], out_proj[i]
v_proj = get_proj(v, o)
v[:] = torch.matmul(v_proj, v)
o[:] = torch.matmul(v_proj, o)
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
@ -1026,13 +1105,22 @@ class Conv2dSubsampling(nn.Module):
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
feature_dim = 50
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5,
)
# Make sure the forward pass runs, and that orthogonalize() does not
# change its output.
feats = torch.randn(batch_size, seq_len, feature_dim)
x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
c.eval() # use test mode and warmup=1.0 so it is deterministic.
y1 = c(feats, x_lens, warmup=1.0)[0]
c.orthogonalize()
y2 = c(feats, x_lens, warmup=1.0)[0]
diff_norm = (y1-y2).norm()
y_norm = y1.norm()
print(f"diff_norm={diff_norm}, y_norm={y_norm}")
assert diff_norm < 0.001 * y_norm

View File

@ -701,6 +701,11 @@ def train_one_epoch(
continue
cur_batch_idx = batch_idx
if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0:
mmodel = model.module if hasattr(model, 'module') else model
mmodel.encoder.orthogonalize()
optimizer.reset()
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])