mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00

Commit d61c8aa3bc (parent 67f916e599)

    Hopefully this finishes the full orthogonalization.
conformer.py

@@ -31,6 +31,9 @@ from scaling import (
     ScaledLinear,
 )
 from torch import Tensor, nn
+from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
+    apply_transformation_in, apply_transformation_out, apply_transformation_inout
+
 from icefall.utils import make_pad_mask
@@ -128,11 +131,24 @@ class Conformer(EncoderInterface):
         return x, lengths

     def diagonalize(self) -> None:
-        # currently only diagonalize the self-attention modules, but could in principle
-        # do more layers.
+        # This only diagonalizes the self-attention modules; to diagonalize the embedding
+        # space, call diagonalize() from class Transducer in model.py.
         for m in self.encoder.layers:
             m.self_attn.diagonalize()

+    @torch.no_grad()
+    def get_diag_covar_out(self) -> Tensor:
+        return (self.encoder_embed.get_diag_covar_out() +
+                sum([l.get_diag_covar_inout() for l in self.encoder.layers]))
+
+    @torch.no_grad()
+    def apply_transformation_out(self, t: Tensor) -> None:
+        self.encoder_embed.apply_transformation_out(t)
+        for l in self.encoder.layers:
+            l.apply_transformation_inout(t)
+
+
 class ConformerEncoderLayer(nn.Module):
     """
     ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
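The methods above all act on the same d_model embedding space: the embedding's output covariance and every layer's in/out covariance are summed, and one orthogonal change of basis is later applied to all of them at once. Below is a minimal sketch, not icefall code, of why that is function-preserving; the two toy nn.Linear layers and the in-place weight updates are made up for the example, assuming the same "apply out of the producer, in of the consumer" convention as the diff.

    # Toy sketch (not part of icefall): rotating the shared embedding space with an
    # orthogonal matrix t, applied "out" of the producer and "in" of the consumer,
    # leaves the composed function unchanged.
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    dim = 8
    producer = nn.Linear(4, dim)   # writes into the embedding space
    consumer = nn.Linear(dim, 3)   # reads from the embedding space
    x = torch.randn(10, 4)

    with torch.no_grad():
        y_ref = consumer(producer(x))

        # Any orthogonal t indexed (new_dim, old_dim); here from a QR decomposition.
        t, _ = torch.linalg.qr(torch.randn(dim, dim))

        # "out" transformation on the producer: new activations are t @ old activations.
        producer.weight[:] = t @ producer.weight
        producer.bias[:] = t @ producer.bias
        # "in" transformation on the consumer: undo the rotation on its input side.
        consumer.weight[:] = consumer.weight @ t.t()

        y_new = consumer(producer(x))

    print(torch.allclose(y_ref, y_new, atol=1e-5))  # True: the function is preserved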
@@ -265,6 +281,20 @@ class ConformerEncoderLayer(nn.Module):

         return src

+    @torch.no_grad()
+    def get_diag_covar_inout(self) -> Tensor:
+        return (get_diag_covar_inout(self.feed_forward) +
+                get_diag_covar_inout(self.feed_forward_macaron) +
+                self.self_attn.get_diag_covar_inout() +
+                self.conv_module.get_diag_covar_inout())
+
+    @torch.no_grad()
+    def apply_transformation_inout(self, t: Tensor) -> None:
+        apply_transformation_inout(self.feed_forward, t)
+        apply_transformation_inout(self.feed_forward_macaron, t)
+        self.self_attn.apply_transformation_inout(t)
+        self.conv_module.apply_transformation_inout(t)
+

 class ConformerEncoder(nn.Module):
     r"""ConformerEncoder is a stack of N encoder layers
@@ -862,17 +892,17 @@ class RelPositionMultiheadAttention(nn.Module):
            return covar * (x.shape[0] / covar.trace())

-        def get_proj(*args) -> Tensor:
+        def get_transformation(*args) -> Tensor:
            """
-            Returns a covariance-diagonalizing projection that diagonalizes
-            the summed covariance from these two projections.  If mat1,mat2
+            Returns a covariance-diagonalizing transformation that diagonalizes
+            the summed covariance from these two transformations.  If mat1,mat2
            are of shape (dim0, dim1), it's the (dim0, dim0) covariance,
            that we diagonalize.

            Args: mat1, mat2, etc., which should all be matrices of the same
            shape (dim0, dim1)

-            Returns: a projection indexed (new_dim0, old_dim0), i.e. of
+            Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
            shape dim0 by dim0 but 1st index is the newly created indexes.
            """
            cov = get_normalized_covar(args[0])
@@ -889,23 +919,30 @@ class RelPositionMultiheadAttention(nn.Module):
         logging.info("Diagonalizing query/key space")
         for i in range(num_heads):
             q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i]
-            qk_proj = get_proj(q, k, l)
-            q[:] = torch.matmul(qk_proj, q)
-            k[:] = torch.matmul(qk_proj, k)
-            l[:] = torch.matmul(qk_proj, l)
-            pos_u[:] = torch.mv(qk_proj, pos_u)
-            pos_v[:] = torch.mv(qk_proj, pos_v)
+            qk_trans = get_transformation(q, k, l)
+            q[:] = torch.matmul(qk_trans, q)
+            k[:] = torch.matmul(qk_trans, k)
+            l[:] = torch.matmul(qk_trans, l)
+            pos_u[:] = torch.mv(qk_trans, pos_u)
+            pos_v[:] = torch.mv(qk_trans, pos_v)

         # Now do the value space
         logging.info("Diagonalizing value space")
         for i in range(num_heads):
             v, o = value_proj[i], out_proj[i]
-            v_proj = get_proj(v, o)
-            v[:] = torch.matmul(v_proj, v)
-            o[:] = torch.matmul(v_proj, o)
+            v_trans = get_transformation(v, o)
+            v[:] = torch.matmul(v_trans, v)
+            o[:] = torch.matmul(v_trans, o)

     @torch.no_grad()
     def get_diag_covar_inout(self) -> Tensor:
         return (get_diag_covar_in(self.in_proj) +
                 get_diag_covar_out(self.out_proj))

+    @torch.no_grad()
+    def apply_transformation_inout(self, t: Tensor) -> None:
+        apply_transformation_in(self.in_proj, t)
+        apply_transformation_out(self.out_proj, t)
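The per-head loop above can apply one qk_trans to the query, key and positional projections (and, via torch.mv, to pos_bias_u / pos_bias_v) because attention scores depend only on inner products between query-side and key-side vectors, and an orthogonal transformation preserves those. A minimal sketch with toy shapes, not icefall code:

    # Toy check (not icefall code): applying the same orthogonal transformation U to the
    # per-head query and key projections leaves the raw attention scores q . k unchanged.
    import torch

    torch.manual_seed(0)
    head_dim, model_dim, n = 16, 64, 10
    Wq = torch.randn(head_dim, model_dim)   # per-head query projection
    Wk = torch.randn(head_dim, model_dim)   # per-head key projection
    x = torch.randn(n, model_dim)

    scores_ref = (x @ Wq.t()) @ (x @ Wk.t()).t()   # (n, n) raw attention scores

    U, _ = torch.linalg.qr(torch.randn(head_dim, head_dim))  # orthogonal, (new_dim, old_dim)
    Wq2 = U @ Wq   # same update as q[:] = torch.matmul(qk_trans, q)
    Wk2 = U @ Wk   # same update as k[:] = torch.matmul(qk_trans, k)

    scores_new = (x @ Wq2.t()) @ (x @ Wk2.t()).t()

    print(torch.allclose(scores_ref, scores_new, atol=1e-3))  # True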
@@ -1009,6 +1046,17 @@ class ConvolutionModule(nn.Module):

         return x.permute(2, 0, 1)

+    @torch.no_grad()
+    def get_diag_covar_inout(self) -> Tensor:
+        return (get_diag_covar_in(self.pointwise_conv1) +
+                get_diag_covar_out(self.pointwise_conv2))
+
+    @torch.no_grad()
+    def apply_transformation_inout(self, t: Tensor) -> None:
+        apply_transformation_in(self.pointwise_conv1, t)
+        apply_transformation_out(self.pointwise_conv2, t)
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).
@@ -1103,6 +1151,15 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_balancer(x)
         return x

+    @torch.no_grad()
+    def get_diag_covar_out(self) -> Tensor:
+        return get_diag_covar_out(self.out)
+
+    @torch.no_grad()
+    def apply_transformation_out(self, t: Tensor) -> None:
+        apply_transformation_out(self.out, t)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
diagonalize.py

@@ -59,7 +59,7 @@ def get_diag_covar_in(m: nn.Module) -> Tensor:
         w = w.reshape(in_channels, -1)
         return _get_normalized_covar(w)  # (in_channels, in_channels)
     elif isinstance(m, nn.Sequential):
-        return get_diag_covar_in(m[0])
+        return get_diag_covar_in(m[0], t)
     else:
         # some modules have this function; if not, at this point, it is an error.
         return m.get_diag_covar_in()
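For reference, a standalone sketch of the per-channel normalized covariance that get_diag_covar_in builds from a weight viewed as (in_channels, -1). The helper name, the Conv1d example and the permute step are assumptions for illustration; the exact reshaping and normalization in diagonalize.py may differ slightly (the trace normalization follows the covar * (x.shape[0] / covar.trace()) line visible in the conformer.py hunk above).

    # Sketch (not icefall code): per-channel covariance of a weight, normalized so
    # that its trace equals the number of channels.
    import torch
    import torch.nn as nn

    def normalized_covar(x: torch.Tensor) -> torch.Tensor:
        # x: (dim0, dim1); returns (dim0, dim0) with trace == dim0.
        covar = x @ x.t()
        return covar * (x.shape[0] / covar.trace())

    conv = nn.Conv1d(in_channels=8, out_channels=32, kernel_size=3)
    # View the weight from the input-channel side: (in_channels, out_channels * kernel).
    w = conv.weight.detach().permute(1, 0, 2).reshape(8, -1)
    cov_in = normalized_covar(w)                 # (8, 8), analogous to get_diag_covar_in
    print(cov_in.shape, float(cov_in.trace()))   # torch.Size([8, 8]) and trace ~= 8.0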
@@ -135,7 +135,7 @@ def apply_transformation_in(m: nn.Module, t: Tensor) -> None:
         w = w.reshape(m.weight.shape)  # (out_channels, in_channels, [1 or 2 kernel dims])
         m.weight[:] = w
     elif isinstance(m, nn.Sequential):
-        apply_transformation_in(m[0])
+        apply_transformation_in(m[0], t)
     else:
         # some modules have this function; if not, at this point, it is an error.
         m.apply_transformation_in(t)
@@ -167,7 +167,7 @@ def apply_transformation_out(m: nn.Module, t: Tensor) -> None:
         if m.bias is not None:
             m.bias[:] = torch.matmul(t, m.bias)
     elif isinstance(m, nn.Sequential):
-        apply_transformation_out(m[-1])
+        apply_transformation_out(m[-1], t)
     else:
         # some modules have this function; if not, at this point, it is an error.
         m.apply_transformation_out(t)
@@ -193,12 +193,9 @@ def get_transformation(cov: Tensor) -> Tensor:
     Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
     shape dim0 by dim0 but 1st index is the newly created indexes.
     """
-    cov = get_normalized_covar(args[0])
-    for a in args[1:]:
-        cov += get_normalized_covar(a)
     old_diag_stddev = cov.diag().var().sqrt().item()
     l, U = cov.symeig(eigenvectors=True)
     new_diag_stddev = l.var().sqrt().item()
     logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
-                 f"to {new_diag_stddev:.3e}")
+                 f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
     return U.t()  # U.t() is indexed (new_dim, old_dim)
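The returned U.t() is indexed (new_dim0, old_dim0): rotating the covariance into its eigenbasis makes it exactly diagonal, which is what the before/after diagonal-variance logging above measures. A standalone sketch, not icefall code, using torch.linalg.eigh (the recommended replacement for the deprecated torch.symeig used above, with the same ascending eigenvalue order):

    # Sketch (not icefall code): the eigenbasis of a symmetric covariance gives a
    # transformation T, indexed (new_dim, old_dim), with T @ cov @ T.t() diagonal.
    import torch

    torch.manual_seed(0)
    dim = 6
    a = torch.randn(dim, 20)
    cov = a @ a.t()                      # symmetric positive semi-definite
    cov = cov * (dim / cov.trace())      # normalize trace to dim, as in the diff

    l, U = torch.linalg.eigh(cov)        # cov == U @ diag(l) @ U.t(); eigenvalues ascending
    T = U.t()                            # (new_dim, old_dim), like the U.t() returned above

    rotated = T @ cov @ T.t()
    print(torch.allclose(rotated, torch.diag(l), atol=1e-5))  # True: fully diagonal
    print(float(cov.diag().var()), float(l.var()))            # diag variance before vs after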
model.py

@@ -21,7 +21,7 @@ from torch import Tensor
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
-from diagonalize import get_diag_covar_in
+from diagonalize import get_diag_covar_in, get_transformation, apply_transformation_in, apply_transformation_out

 from icefall.utils import add_sos
@@ -195,7 +195,73 @@ class Transducer(nn.Module):
         return (simple_loss, pruned_loss)


-    def get_diag_covar_in(self) -> Tensor:
-        return (get_diag_covar_in(self.simple_am_proj) +
-                get_diag_covar_in(joiner.encoder_proj) +
+    def diagonalize(self) -> None:
+        self.encoder.diagonalize()  # diagonalizes self_attn layers.
+
+        diag_covar = (get_diag_covar_in(self.simple_am_proj) +
+                      get_diag_covar_in(self.joiner.encoder_proj) +
+                      self.encoder.get_diag_covar_out())
+        t = get_transformation(diag_covar)
+        self.encoder.apply_transformation_out(t)
+        apply_transformation_in(self.simple_am_proj, t)
+        apply_transformation_in(self.joiner.encoder_proj, t)
+
+
+def _test_model():
+    import logging
+    logging.getLogger().setLevel(logging.INFO)
+    from conformer import Conformer
+    from joiner import Joiner
+    from decoder import Decoder
+    feature_dim = 40
+    attention_dim = 256
+    encoder_dim = 512
+    decoder_dim = 513
+    joiner_dim = 514
+    vocab_size = 1000
+    encoder = Conformer(num_features=40,
+                        subsampling_factor=4,
+                        d_model=encoder_dim,
+                        nhead=4,
+                        dim_feedforward=512,
+                        num_encoder_layers=4)
+    decoder = Decoder(
+        vocab_size=600,
+        decoder_dim=decoder_dim,
+        blank_id=0,
+        context_size=2)
+    joiner = Joiner(
+        encoder_dim=encoder_dim,
+        decoder_dim=decoder_dim,
+        joiner_dim=joiner_dim,
+        vocab_size=vocab_size)
+    model = Transducer(encoder=encoder,
+                       decoder=decoder,
+                       joiner=joiner,
+                       encoder_dim=encoder_dim,
+                       decoder_dim=decoder_dim,
+                       joiner_dim=joiner_dim,
+                       vocab_size=vocab_size)
+
+    batch_size = 5
+    seq_len = 50
+
+    feats = torch.randn(batch_size, seq_len, feature_dim)
+    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
+    y = k2.ragged.create_ragged_tensor(torch.arange(5, dtype=torch.int32).reshape(1, 5).expand(batch_size, 5))
+    model.eval()  # eval mode so it's not random.
+    (simple_loss1, pruned_loss1) = model(feats, x_lens, y)
+    model.diagonalize()
+    (simple_loss2, pruned_loss2) = model(feats, x_lens, y)
+    model.diagonalize()
+
+    print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
+    print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
+
+
+if __name__ == '__main__':
+    _test_model()