Hopefully this finishes the full orthogonalization.

This commit is contained in:
Daniel Povey 2022-05-16 19:18:12 +08:00
parent 67f916e599
commit d61c8aa3bc
3 changed files with 148 additions and 28 deletions

View File

@ -31,6 +31,9 @@ from scaling import (
ScaledLinear, ScaledLinear,
) )
from torch import Tensor, nn from torch import Tensor, nn
from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
apply_transformation_in, apply_transformation_out, apply_transformation_inout
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask
@ -128,11 +131,24 @@ class Conformer(EncoderInterface):
return x, lengths return x, lengths
def diagonalize(self) -> None: def diagonalize(self) -> None:
# currently only diagonalize the self-attention modules, but could in principle # This only diagonalizes the self-attention modules; to diagonalize the embedding
# do more layers. # space, call diagonalize() from class Transducer in model.py.
for m in self.encoder.layers: for m in self.encoder.layers:
m.self_attn.diagonalize() m.self_attn.diagonalize()
@torch.no_grad()
def get_diag_covar_out(self) -> Tensor:
    """Return the summed covariance statistic for this encoder's output space.

    Adds the value reported by ``self.encoder_embed.get_diag_covar_out()`` to
    the values reported by ``get_diag_covar_inout()`` on every layer in
    ``self.encoder.layers``.

    Returns:
      The sum of those per-module statistics.
      NOTE(review): presumably a square matrix over the embedding dimension,
      as consumed by get_transformation() -- confirm against diagonalize.py.
    """
    embed_covar = self.encoder_embed.get_diag_covar_out()
    # sum() starts from the int 0, so this also works with no layers.
    layers_covar = sum(
        layer.get_diag_covar_inout() for layer in self.encoder.layers
    )
    return embed_covar + layers_covar
@torch.no_grad()
def apply_transformation_out(self, t: Tensor) -> None:
    """Apply the transformation `t` to every module that touches this encoder's output space.

    Args:
      t: the transformation to apply; it is forwarded unchanged to
         ``self.encoder_embed.apply_transformation_out`` and then to
         ``apply_transformation_inout`` of each layer in ``self.encoder.layers``.
         NOTE(review): presumably the matrix returned by get_transformation()
         -- confirm.
    """
    self.encoder_embed.apply_transformation_out(t)
    for layer in self.encoder.layers:
        # Each layer is handed the same `t` via its "inout" method.
        layer.apply_transformation_inout(t)
class ConformerEncoderLayer(nn.Module): class ConformerEncoderLayer(nn.Module):
""" """
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@ -265,6 +281,20 @@ class ConformerEncoderLayer(nn.Module):
return src return src
@torch.no_grad()
def get_diag_covar_inout(self) -> Tensor:
    """Return the summed covariance statistic of this layer's four sub-modules.

    The two feed-forward modules go through the module-level
    ``get_diag_covar_inout`` helper imported from ``diagonalize``; the
    self-attention and convolution modules provide their own
    ``get_diag_covar_inout()`` methods.
    """
    # Inside a method body the bare name resolves to the imported helper,
    # not to this method.
    ff_covar = get_diag_covar_inout(self.feed_forward)
    macaron_covar = get_diag_covar_inout(self.feed_forward_macaron)
    attn_covar = self.self_attn.get_diag_covar_inout()
    conv_covar = self.conv_module.get_diag_covar_inout()
    return ff_covar + macaron_covar + attn_covar + conv_covar
@torch.no_grad()
def apply_transformation_inout(self, t: Tensor) -> None:
    """Apply the transformation `t` to all four sub-modules of this layer.

    Args:
      t: the transformation, forwarded unchanged to each sub-module.
         NOTE(review): presumably the matrix returned by get_transformation()
         -- confirm.
    """
    # The feed-forward modules are handled by the module-level helper imported
    # from `diagonalize` (the bare name does not refer to this method).
    for feed_forward in (self.feed_forward, self.feed_forward_macaron):
        apply_transformation_inout(feed_forward, t)
    # The attention and convolution modules implement the method themselves.
    self.self_attn.apply_transformation_inout(t)
    self.conv_module.apply_transformation_inout(t)
class ConformerEncoder(nn.Module): class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers r"""ConformerEncoder is a stack of N encoder layers
@ -862,17 +892,17 @@ class RelPositionMultiheadAttention(nn.Module):
return covar * (x.shape[0] / covar.trace()) return covar * (x.shape[0] / covar.trace())
def get_proj(*args) -> Tensor: def get_transformation(*args) -> Tensor:
""" """
Returns a covariance-diagonalizing projection that diagonalizes Returns a covariance-diagonalizing transformation that diagonalizes
the summed covariance from these two projections. If mat1,mat2 the summed covariance from these two transformations. If mat1,mat2
are of shape (dim0, dim1), it's the (dim0, dim0) covariance, are of shape (dim0, dim1), it's the (dim0, dim0) covariance,
that we diagonalize. that we diagonalize.
Args: mat1, mat2, etc., which should all be matrices of the same Args: mat1, mat2, etc., which should all be matrices of the same
shape (dim0, dim1) shape (dim0, dim1)
Returns: a projection indexed (new_dim0, old_dim0), i.e. of Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
shape dim0 by dim0 but 1st index is the newly created indexes. shape dim0 by dim0 but 1st index is the newly created indexes.
""" """
cov = get_normalized_covar(args[0]) cov = get_normalized_covar(args[0])
@ -889,23 +919,30 @@ class RelPositionMultiheadAttention(nn.Module):
logging.info("Diagonalizing query/key space") logging.info("Diagonalizing query/key space")
for i in range(num_heads): for i in range(num_heads):
q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i] q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i]
qk_proj = get_proj(q, k, l) qk_trans = get_transformation(q, k, l)
q[:] = torch.matmul(qk_proj, q) q[:] = torch.matmul(qk_trans, q)
k[:] = torch.matmul(qk_proj, k) k[:] = torch.matmul(qk_trans, k)
l[:] = torch.matmul(qk_proj, l) l[:] = torch.matmul(qk_trans, l)
pos_u[:] = torch.mv(qk_proj, pos_u) pos_u[:] = torch.mv(qk_trans, pos_u)
pos_v[:] = torch.mv(qk_proj, pos_v) pos_v[:] = torch.mv(qk_trans, pos_v)
# Now do the value space # Now do the value space
logging.info("Diagonalizing value space") logging.info("Diagonalizing value space")
for i in range(num_heads): for i in range(num_heads):
v, o = value_proj[i], out_proj[i] v, o = value_proj[i], out_proj[i]
v_proj = get_proj(v, o) v_trans = get_transformation(v, o)
v[:] = torch.matmul(v_proj, v) v[:] = torch.matmul(v_trans, v)
o[:] = torch.matmul(v_proj, o) o[:] = torch.matmul(v_trans, o)
@torch.no_grad()
def get_diag_covar_inout(self) -> Tensor:
    """Return the summed covariance statistic for this attention module.

    Combines the input-side statistic of ``self.in_proj`` with the
    output-side statistic of ``self.out_proj``, both computed by the helpers
    imported from ``diagonalize``.
    """
    input_side = get_diag_covar_in(self.in_proj)
    output_side = get_diag_covar_out(self.out_proj)
    return input_side + output_side
@torch.no_grad()
def apply_transformation_inout(self, t: Tensor) -> None:
    """Apply the transformation `t` to this attention module's input and output.

    Args:
      t: the transformation; applied to the input side of ``self.in_proj``
         and to the output side of ``self.out_proj`` through the helpers
         imported from ``diagonalize``.
    """
    # Input side first, then output side, as in the covariance accumulation.
    apply_transformation_in(self.in_proj, t)
    apply_transformation_out(self.out_proj, t)
@ -1009,6 +1046,17 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1) return x.permute(2, 0, 1)
@torch.no_grad()
def get_diag_covar_inout(self) -> Tensor:
    """Return the summed covariance statistic for this convolution module.

    Combines the input-side statistic of ``self.pointwise_conv1`` with the
    output-side statistic of ``self.pointwise_conv2``, both computed by the
    helpers imported from ``diagonalize``.
    """
    input_side = get_diag_covar_in(self.pointwise_conv1)
    output_side = get_diag_covar_out(self.pointwise_conv2)
    return input_side + output_side
@torch.no_grad()
def apply_transformation_inout(self, t: Tensor) -> None:
    """Apply the transformation `t` to this convolution module's input and output.

    Args:
      t: the transformation; applied to the input side of
         ``self.pointwise_conv1`` and to the output side of
         ``self.pointwise_conv2`` through the helpers imported from
         ``diagonalize``.
    """
    # Only the first and last pointwise convolutions are passed to the helpers.
    apply_transformation_in(self.pointwise_conv1, t)
    apply_transformation_out(self.pointwise_conv2, t)
class Conv2dSubsampling(nn.Module): class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length). """Convolutional 2D subsampling (to 1/4 length).
@ -1103,6 +1151,15 @@ class Conv2dSubsampling(nn.Module):
x = self.out_balancer(x) x = self.out_balancer(x)
return x return x
@torch.no_grad()
def get_diag_covar_out(self) -> Tensor:
    """Return the output-side covariance statistic of ``self.out``."""
    # The bare name below is the module-level helper imported from
    # `diagonalize`, not a recursive call to this method.
    out_covar = get_diag_covar_out(self.out)
    return out_covar
@torch.no_grad()
def apply_transformation_out(self, t: Tensor) -> None:
    """Apply the transformation `t` to the output side of ``self.out``.

    Args:
      t: the transformation, forwarded unchanged to the helper.
    """
    # Delegates to the module-level helper imported from `diagonalize`
    # (the bare name is not a recursive call to this method).
    apply_transformation_out(self.out, t)
if __name__ == "__main__": if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)

View File

@ -59,7 +59,7 @@ def get_diag_covar_in(m: nn.Module) -> Tensor:
w = w.reshape(in_channels, -1) w = w.reshape(in_channels, -1)
return _get_normalized_covar(w) # (in_channels, in_channels) return _get_normalized_covar(w) # (in_channels, in_channels)
elif isinstance(m, nn.Sequential): elif isinstance(m, nn.Sequential):
return get_diag_covar_in(m[0]) return get_diag_covar_in(m[0], t)
else: else:
# some modules have this function; if not, at this point, it is an error. # some modules have this function; if not, at this point, it is an error.
return m.get_diag_covar_in() return m.get_diag_covar_in()
@ -135,7 +135,7 @@ def apply_transformation_in(m: nn.Module, t: Tensor) -> None:
w = w.reshape(m.weight.shape) # (out_channels, in_channels, [1 or 2 kernel dims]) w = w.reshape(m.weight.shape) # (out_channels, in_channels, [1 or 2 kernel dims])
m.weight[:] = w m.weight[:] = w
elif isinstance(m, nn.Sequential): elif isinstance(m, nn.Sequential):
apply_transformation_in(m[0]) apply_transformation_in(m[0], t)
else: else:
# some modules have this function; if not, at this point, it is an error. # some modules have this function; if not, at this point, it is an error.
m.apply_transformation_in(t) m.apply_transformation_in(t)
@ -167,7 +167,7 @@ def apply_transformation_out(m: nn.Module, t: Tensor) -> None:
if m.bias is not None: if m.bias is not None:
m.bias[:] = torch.matmul(t, m.bias) m.bias[:] = torch.matmul(t, m.bias)
elif isinstance(m, nn.Sequential): elif isinstance(m, nn.Sequential):
apply_transformation_out(m[-1]) apply_transformation_out(m[-1], t)
else: else:
# some modules have this function; if not, at this point, it is an error. # some modules have this function; if not, at this point, it is an error.
m.apply_transformation_out(t) m.apply_transformation_out(t)
@ -193,12 +193,9 @@ def get_transformation(cov: Tensor) -> Tensor:
Returns: a transformation indexed (new_dim0, old_dim0), i.e. of Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
shape dim0 by dim0 but 1st index is the newly created indexes. shape dim0 by dim0 but 1st index is the newly created indexes.
""" """
cov = get_normalized_covar(args[0])
for a in args[1:]:
cov += get_normalized_covar(a)
old_diag_stddev = cov.diag().var().sqrt().item() old_diag_stddev = cov.diag().var().sqrt().item()
l, U = cov.symeig(eigenvectors=True) l, U = cov.symeig(eigenvectors=True)
new_diag_stddev = l.var().sqrt().item() new_diag_stddev = l.var().sqrt().item()
logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} " logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
f"to {new_diag_stddev:.3e}") f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
return U.t() # U.t() is indexed (new_dim, old_dim) return U.t() # U.t() is indexed (new_dim, old_dim)

View File

@ -21,7 +21,7 @@ from torch import Tensor
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ScaledLinear from scaling import ScaledLinear
from diagonalize import get_diag_covar_in from diagonalize import get_diag_covar_in, apply_transformation_in, get_transformation, apply_transformation_in, apply_transformation_out
from icefall.utils import add_sos from icefall.utils import add_sos
@ -195,7 +195,73 @@ class Transducer(nn.Module):
return (simple_loss, pruned_loss) return (simple_loss, pruned_loss)
def get_diag_covar_in(self) -> Tensor:
return (get_diag_covar_in(self.simple_am_proj) + def diagonalize(self) -> None:
get_diag_covar_in(joiner.encoder_proj) + self.encoder.diagonalize() # diagonalizes self_attn layers.
self.encoder.get_diag_covar_out())
diag_covar = (get_diag_covar_in(self.simple_am_proj) +
get_diag_covar_in(self.joiner.encoder_proj) +
self.encoder.get_diag_covar_out())
t = get_transformation(diag_covar)
self.encoder.apply_transformation_out(t)
apply_transformation_in(self.simple_am_proj, t)
apply_transformation_in(self.joiner.encoder_proj, t)
def _test_model():
    """Smoke-test Transducer.diagonalize() on a small, randomly initialised model.

    Builds a Conformer encoder, a Decoder and a Joiner, wraps them in a
    Transducer, computes the simple and pruned losses on random features,
    calls diagonalize(), recomputes both losses on the same batch and prints
    the two pairs so they can be compared by eye.  Nothing is asserted.

    Relies on `torch` and `k2` being available at module level (their imports
    are not part of the code shown here).
    """
    import logging
    logging.getLogger().setLevel(logging.INFO)
    from conformer import Conformer
    from joiner import Joiner
    from decoder import Decoder

    feature_dim = 40
    encoder_dim = 512
    # NOTE(review): the three dims are presumably chosen distinct so that any
    # mix-up between them shows up as a shape error -- confirm.
    decoder_dim = 513
    joiner_dim = 514
    vocab_size = 1000

    # Use the declared constants rather than repeating literals, so that the
    # random features below and the encoder always agree on the input size.
    encoder = Conformer(num_features=feature_dim,
                        subsampling_factor=4,
                        d_model=encoder_dim,
                        nhead=4,
                        dim_feedforward=512,
                        num_encoder_layers=4)

    # The decoder shares the vocabulary size given to the joiner and the model.
    decoder = Decoder(
        vocab_size=vocab_size,
        decoder_dim=decoder_dim,
        blank_id=0,
        context_size=2)

    joiner = Joiner(
        encoder_dim=encoder_dim,
        decoder_dim=decoder_dim,
        joiner_dim=joiner_dim,
        vocab_size=vocab_size)

    model = Transducer(encoder=encoder,
                       decoder=decoder,
                       joiner=joiner,
                       encoder_dim=encoder_dim,
                       decoder_dim=decoder_dim,
                       joiner_dim=joiner_dim,
                       vocab_size=vocab_size)

    batch_size = 5
    seq_len = 50

    feats = torch.randn(batch_size, seq_len, feature_dim)
    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
    # Every utterance in the batch gets the same label sequence 0..4.
    y = k2.ragged.create_ragged_tensor(torch.arange(5, dtype=torch.int32).reshape(1,5).expand(batch_size,5))

    model.eval()  # eval mode so it's not random.

    (simple_loss1, pruned_loss1) = model(feats, x_lens, y)
    model.diagonalize()
    (simple_loss2, pruned_loss2) = model(feats, x_lens, y)
    # NOTE(review): the second call has no effect on the values printed below;
    # presumably it is here to check from the INFO logs what diagonalizing an
    # already-diagonalized model does -- confirm.
    model.diagonalize()

    print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
    print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
# Run the diagonalization smoke test when this file is executed as a script.
if __name__ == "__main__":
    _test_model()