Hopefully this finishes the full orthogonalization.

This commit is contained in:
Daniel Povey 2022-05-16 19:18:12 +08:00
parent 67f916e599
commit d61c8aa3bc
3 changed files with 148 additions and 28 deletions

View File

@ -31,6 +31,9 @@ from scaling import (
ScaledLinear,
)
from torch import Tensor, nn
from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
apply_transformation_in, apply_transformation_out, apply_transformation_inout
from icefall.utils import make_pad_mask
@ -128,11 +131,24 @@ class Conformer(EncoderInterface):
return x, lengths
def diagonalize(self) -> None:
    """Diagonalize the self-attention module of every encoder layer.

    This only diagonalizes the self-attention modules (in principle more
    layers could be handled here).  To diagonalize the embedding space,
    call diagonalize() on class Transducer in model.py, which calls this
    method first.
    """
    for layer in self.encoder.layers:
        attn = layer.self_attn
        attn.diagonalize()
@torch.no_grad()
def get_diag_covar_out(self) -> Tensor:
    """Return the summed covariance statistics for the encoder output space.

    The result is `self.encoder_embed.get_diag_covar_out()` plus the sum of
    `get_diag_covar_inout()` over all layers in `self.encoder.layers`.
    """
    embed_covar = self.encoder_embed.get_diag_covar_out()
    # Each layer both reads from and writes to this space, hence "inout".
    layers_covar = sum(
        layer.get_diag_covar_inout() for layer in self.encoder.layers
    )
    return embed_covar + layers_covar
@torch.no_grad()
def apply_transformation_out(self, t: Tensor) -> None:
    """Apply the transformation `t` to the encoder output space.

    Args:
      t: the transformation; it is forwarded unchanged to
         `self.encoder_embed.apply_transformation_out` and then to each
         encoder layer's `apply_transformation_inout`, in layer order.
    """
    embed = self.encoder_embed
    embed.apply_transformation_out(t)
    for layer in self.encoder.layers:
        layer.apply_transformation_inout(t)
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@ -265,6 +281,20 @@ class ConformerEncoderLayer(nn.Module):
return src
@torch.no_grad()
def get_diag_covar_inout(self) -> Tensor:
    """Return this layer's summed covariance statistics.

    Adds, in this order, the contributions of the feed-forward module, the
    macaron feed-forward module, the self-attention module and the
    convolution module.  The two feed-forward modules go through the
    module-level helper `get_diag_covar_inout`; the other two provide
    their own `get_diag_covar_inout()` method.
    """
    total = get_diag_covar_inout(self.feed_forward)
    # Out-of-place additions so no sub-module's returned tensor is modified.
    total = total + get_diag_covar_inout(self.feed_forward_macaron)
    total = total + self.self_attn.get_diag_covar_inout()
    total = total + self.conv_module.get_diag_covar_inout()
    return total
@torch.no_grad()
def apply_transformation_inout(self, t: Tensor) -> None:
    """Apply the transformation `t` to every sub-module of this layer.

    Args:
      t: the transformation; it is passed unchanged to the feed-forward
         modules (via the module-level helper `apply_transformation_inout`)
         and then to the self-attention and convolution modules (via
         their own methods).
    """
    for ff_module in (self.feed_forward, self.feed_forward_macaron):
        apply_transformation_inout(ff_module, t)
    self.self_attn.apply_transformation_inout(t)
    self.conv_module.apply_transformation_inout(t)
class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers
@ -862,17 +892,17 @@ class RelPositionMultiheadAttention(nn.Module):
return covar * (x.shape[0] / covar.trace())
def get_proj(*args) -> Tensor:
def get_transformation(*args) -> Tensor:
"""
Returns a covariance-diagonalizing projection that diagonalizes
the summed covariance from these two projections. If mat1,mat2
Returns a covariance-diagonalizing transformation that diagonalizes
the summed covariance from these two transformations. If mat1,mat2
are of shape (dim0, dim1), it's the (dim0, dim0) covariance,
that we diagonalize.
Args: mat1, mat2, etc., which should all be matrices of the same
shape (dim0, dim1)
Returns: a projection indexed (new_dim0, old_dim0), i.e. of
Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
shape dim0 by dim0 but 1st index is the newly created indexes.
"""
cov = get_normalized_covar(args[0])
@ -889,23 +919,30 @@ class RelPositionMultiheadAttention(nn.Module):
logging.info("Diagonalizing query/key space")
for i in range(num_heads):
q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i]
qk_proj = get_proj(q, k, l)
q[:] = torch.matmul(qk_proj, q)
k[:] = torch.matmul(qk_proj, k)
l[:] = torch.matmul(qk_proj, l)
pos_u[:] = torch.mv(qk_proj, pos_u)
pos_v[:] = torch.mv(qk_proj, pos_v)
qk_trans = get_transformation(q, k, l)
q[:] = torch.matmul(qk_trans, q)
k[:] = torch.matmul(qk_trans, k)
l[:] = torch.matmul(qk_trans, l)
pos_u[:] = torch.mv(qk_trans, pos_u)
pos_v[:] = torch.mv(qk_trans, pos_v)
# Now do the value space
logging.info("Diagonalizing value space")
for i in range(num_heads):
v, o = value_proj[i], out_proj[i]
v_proj = get_proj(v, o)
v[:] = torch.matmul(v_proj, v)
o[:] = torch.matmul(v_proj, o)
v_trans = get_transformation(v, o)
v[:] = torch.matmul(v_trans, v)
o[:] = torch.matmul(v_trans, o)
@torch.no_grad()
def get_diag_covar_inout(self) -> Tensor:
    """Return the attention module's summed covariance statistics.

    This is `get_diag_covar_in(self.in_proj)` (the input side of the input
    projection) plus `get_diag_covar_out(self.out_proj)` (the output side
    of the output projection).
    """
    in_covar = get_diag_covar_in(self.in_proj)
    out_covar = get_diag_covar_out(self.out_proj)
    return in_covar + out_covar
@torch.no_grad()
def apply_transformation_inout(self, t: Tensor) -> None:
    """Apply the transformation `t` to the attention module.

    Args:
      t: the transformation; applied to the input side of `self.in_proj`
         and to the output side of `self.out_proj`.
    """
    in_proj, out_proj = self.in_proj, self.out_proj
    apply_transformation_in(in_proj, t)
    apply_transformation_out(out_proj, t)
@ -1009,6 +1046,17 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1)
@torch.no_grad()
def get_diag_covar_inout(self) -> Tensor:
    """Return the convolution module's summed covariance statistics.

    This is `get_diag_covar_in(self.pointwise_conv1)` plus
    `get_diag_covar_out(self.pointwise_conv2)`.
    """
    in_covar = get_diag_covar_in(self.pointwise_conv1)
    out_covar = get_diag_covar_out(self.pointwise_conv2)
    return in_covar + out_covar
@torch.no_grad()
def apply_transformation_inout(self, t: Tensor) -> None:
    """Apply the transformation `t` to the convolution module.

    Args:
      t: the transformation; applied to the input side of
         `self.pointwise_conv1` and to the output side of
         `self.pointwise_conv2`.
    """
    first_conv, last_conv = self.pointwise_conv1, self.pointwise_conv2
    apply_transformation_in(first_conv, t)
    apply_transformation_out(last_conv, t)
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
@ -1103,6 +1151,15 @@ class Conv2dSubsampling(nn.Module):
x = self.out_balancer(x)
return x
@torch.no_grad()
def get_diag_covar_out(self) -> Tensor:
    """Return the covariance statistics of the output side of `self.out`.

    Delegates to the module-level helper `get_diag_covar_out`.
    """
    out_layer = self.out
    return get_diag_covar_out(out_layer)
@torch.no_grad()
def apply_transformation_out(self, t: Tensor) -> None:
    """Apply the transformation `t` to the output side of `self.out`.

    Args:
      t: the transformation, forwarded to the module-level helper
         `apply_transformation_out`.
    """
    out_layer = self.out
    apply_transformation_out(out_layer, t)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)

View File

@ -59,7 +59,7 @@ def get_diag_covar_in(m: nn.Module) -> Tensor:
w = w.reshape(in_channels, -1)
return _get_normalized_covar(w) # (in_channels, in_channels)
elif isinstance(m, nn.Sequential):
return get_diag_covar_in(m[0])
return get_diag_covar_in(m[0], t)
else:
# some modules have this function; if not, at this point, it is an error.
return m.get_diag_covar_in()
@ -135,7 +135,7 @@ def apply_transformation_in(m: nn.Module, t: Tensor) -> None:
w = w.reshape(m.weight.shape) # (out_channels, in_channels, [1 or 2 kernel dims])
m.weight[:] = w
elif isinstance(m, nn.Sequential):
apply_transformation_in(m[0])
apply_transformation_in(m[0], t)
else:
# some modules have this function; if not, at this point, it is an error.
m.apply_transformation_in(t)
@ -167,7 +167,7 @@ def apply_transformation_out(m: nn.Module, t: Tensor) -> None:
if m.bias is not None:
m.bias[:] = torch.matmul(t, m.bias)
elif isinstance(m, nn.Sequential):
apply_transformation_out(m[-1])
apply_transformation_out(m[-1], t)
else:
# some modules have this function; if not, at this point, it is an error.
m.apply_transformation_out(t)
@ -193,12 +193,9 @@ def get_transformation(cov: Tensor) -> Tensor:
Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
shape dim0 by dim0 but 1st index is the newly created indexes.
"""
cov = get_normalized_covar(args[0])
for a in args[1:]:
cov += get_normalized_covar(a)
old_diag_stddev = cov.diag().var().sqrt().item()
l, U = cov.symeig(eigenvectors=True)
new_diag_stddev = l.var().sqrt().item()
logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
f"to {new_diag_stddev:.3e}")
f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
return U.t() # U.t() is indexed (new_dim, old_dim)

View File

@ -21,7 +21,7 @@ from torch import Tensor
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from diagonalize import get_diag_covar_in
from diagonalize import get_diag_covar_in, apply_transformation_in, get_transformation, apply_transformation_in, apply_transformation_out
from icefall.utils import add_sos
@ -195,7 +195,73 @@ class Transducer(nn.Module):
return (simple_loss, pruned_loss)
def get_diag_covar_in(self) -> Tensor:
return (get_diag_covar_in(self.simple_am_proj) +
get_diag_covar_in(joiner.encoder_proj) +
@torch.no_grad()
def diagonalize(self) -> None:
    """Diagonalize the encoder's self-attention modules and output space.

    Steps:
      1. Ask the encoder to diagonalize its self-attention layers.
      2. Sum the covariance statistics of every module that consumes the
         encoder output (`simple_am_proj`, `joiner.encoder_proj`) with the
         encoder's own output-side statistics.
      3. Derive a transformation from that summed covariance and apply it
         to the encoder output side and to the input side of both
         consumers.

    The whole procedure runs under `torch.no_grad()`, like the other
    get_diag_covar_* / apply_transformation_* methods: it only inspects and
    rewrites parameters, so no autograd graph should be recorded for it.
    """
    self.encoder.diagonalize()  # diagonalizes self_attn layers.

    diag_covar = (
        get_diag_covar_in(self.simple_am_proj)
        + get_diag_covar_in(self.joiner.encoder_proj)
        + self.encoder.get_diag_covar_out()
    )
    t = get_transformation(diag_covar)

    # The same transformation goes to the producer (encoder output) and to
    # both consumers of that space.
    self.encoder.apply_transformation_out(t)
    apply_transformation_in(self.simple_am_proj, t)
    apply_transformation_in(self.joiner.encoder_proj, t)
def _test_model():
    """Smoke-test Transducer.diagonalize() on a small, randomly initialised model.

    Builds a Conformer encoder, Decoder and Joiner, wraps them in a
    Transducer, and prints the simple and pruned losses computed before and
    after calling diagonalize() so they can be compared by eye.  Nothing is
    asserted.
    """
    import logging
    logging.getLogger().setLevel(logging.INFO)
    from conformer import Conformer
    from joiner import Joiner
    from decoder import Decoder

    # Dimensions are deliberately all different from one another.
    feature_dim = 40
    encoder_dim = 512
    decoder_dim = 513
    joiner_dim = 514
    vocab_size = 1000

    encoder = Conformer(
        num_features=feature_dim,
        subsampling_factor=4,
        d_model=encoder_dim,
        nhead=4,
        dim_feedforward=512,
        num_encoder_layers=4,
    )
    # Use the same vocab_size everywhere so decoder, joiner and model agree.
    decoder = Decoder(
        vocab_size=vocab_size,
        decoder_dim=decoder_dim,
        blank_id=0,
        context_size=2,
    )
    joiner = Joiner(
        encoder_dim=encoder_dim,
        decoder_dim=decoder_dim,
        joiner_dim=joiner_dim,
        vocab_size=vocab_size,
    )
    model = Transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        encoder_dim=encoder_dim,
        decoder_dim=decoder_dim,
        joiner_dim=joiner_dim,
        vocab_size=vocab_size,
    )

    batch_size = 5
    seq_len = 50
    feats = torch.randn(batch_size, seq_len, feature_dim)
    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
    # Every utterance gets the same label sequence 0..4.
    y = k2.ragged.create_ragged_tensor(torch.arange(5, dtype=torch.int32).reshape(1,5).expand(batch_size,5))

    model.eval()  # eval mode so it's not random.
    (simple_loss1, pruned_loss1) = model(feats, x_lens, y)
    model.diagonalize()
    (simple_loss2, pruned_loss2) = model(feats, x_lens, y)
    # Second call: its result is not checked; presumably it only confirms
    # that diagonalizing an already-diagonalized model runs without error.
    model.diagonalize()

    print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
    print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
# Run the diagonalization smoke test when this file is executed as a script.
if __name__ == "__main__":
    _test_model()