Finish feat-diagonalizing code

This commit is contained in:
Daniel Povey 2022-05-17 14:13:56 +08:00
parent 2f2934a115
commit 07d3369234
3 changed files with 95 additions and 12 deletions

View File

@ -32,7 +32,8 @@ from scaling import (
)
from torch import Tensor, nn
from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
apply_transformation_in, apply_transformation_out, apply_transformation_inout
apply_transformation_in, apply_transformation_out, apply_transformation_inout, \
OrthogonalTransformation
from icefall.utils import make_pad_mask
@ -179,6 +180,8 @@ class ConformerEncoderLayer(nn.Module):
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.orth = OrthogonalTransformation(d_model) # not trainable; used in re-diagonalizing features.
self.layer_dropout = layer_dropout
self.d_model = d_model
@ -240,6 +243,7 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
src = self.orth(src)
src_orig = src
warmup_scale = min(0.1 + warmup, 1.0)
@ -288,14 +292,29 @@ class ConformerEncoderLayer(nn.Module):
self.self_attn.get_diag_covar_inout() +
self.conv_module.get_diag_covar_inout())
@torch.no_grad()
def apply_transformation_inout(self, t: Tensor) -> None:
def apply_transformation_in(self, t: Tensor) -> None:
"""
Rotate only the input feature space with an orthogonal matrix.
t is indexed (new_channel_dim, old_channel_dim)
"""
self.orth.apply_transformation_in(t)
@torch.no_grad()
def apply_transformation_out(self, t: Tensor) -> None:
self.orth.apply_transformation_out(t)
apply_transformation_inout(self.feed_forward, t)
apply_transformation_inout(self.feed_forward_macaron, t)
self.self_attn.apply_transformation_inout(t)
self.conv_module.apply_transformation_inout(t)
@torch.no_grad()
def get_transformation_out(self) -> Tensor:
return self.orth.get_transformation_out()
class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers

View File

@ -199,3 +199,64 @@ def get_transformation(cov: Tensor) -> Tensor:
logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
return U.t() # U.t() is indexed (new_dim, old_dim)
class OrthogonalTransformation(nn.Module):
def __init__(self, num_channels: int):
super(OrthogonalTransformation, self).__init__()
# `weight` is indexed (channel_out, channel_in)
self.register_buffer('weight', torch.eye(num_channels)) # not a parameter
self.register_buffer('feats_cov', torch.eye(num_channels)) # not a parameter
self.step = 0 # just to co-ordinate updating feats_cov every 10 batches; not saved to disk.
self.beta = 0.9 # affects how long we remember the stats. not super critical.
def forward(self, x: Tensor):
"""
Args:
x: Tensor of shape (*, num_channel)
Returns:
Tensor of shape (*, num_channels), x multiplied by orthogonal matrix.
"""
x = torch.matmul(x, self.weight.t())
if self.step % 10 == 0 and self.train():
# store covariance after input transform.
# Update covariance stats every 10 batches (in training mode)
f = x.reshape(-1, x.shape[-1])
cov = torch.matmul(f.t(), f) # channel_dim by channel_dim
self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
self.step += 1
return x
@torch.no_grad()
def apply_transformation_in(self, t: Tensor) -> None:
"""
Rotate only the input feature space with an orthogonal matrix.
t is indexed (new_channel_dim, old_channel_dim)
"""
# note, self.weight is indexed (channel_out, channel_in), interpreted
# initially as (channel_out, old_channel_in), which we multiply
# by t.t() which is (old_channel_in, new_channel_in)
self.weight[:] = torch.matmul(self.weight, t.t())
@torch.no_grad()
def apply_transformation_out(self, t: Tensor) -> None:
"""
Rotate only the output feature space with an orthogonal matrix.
t is indexed (new_channel_dim, old_channel_dim)
We don't bother updating the covariance stats; they will decay.
"""
# note, self.weight is indexed (channel_out, channel_in), interpreted
# initially as (old_channel_out, old_channe), which we pre-multiply
# by t which is (new_channel_out, old_channel_out)
self.weight[:] = torch.matmul(t, self.weight)
self.feats_cov[:] = torch.matmul(t, torch.matmul(self.feats_cov, t.t()))
@torch.no_grad()
def get_transformation_out(self) -> Tensor:
# see also get_transformation() above for notes on this.
cov = self.feats_cov
return get_transformation(cov)

View File

@ -197,15 +197,18 @@ class Transducer(nn.Module):
def diagonalize(self) -> None:
self.encoder.diagonalize() # diagonalizes self_attn layers.
cur_transform = None
for l in self.encoder.encoder.layers:
if cur_transform is not None:
l.apply_transformation_in(cur_transform)
cur_transform = l.get_transformation_out()
l.apply_transformation_out(cur_transform)
diag_covar = (get_diag_covar_in(self.simple_am_proj) +
get_diag_covar_in(self.joiner.encoder_proj) +
self.encoder.get_diag_covar_out())
t = get_transformation(diag_covar)
self.encoder.apply_transformation_out(t)
apply_transformation_in(self.simple_am_proj, t)
apply_transformation_in(self.joiner.encoder_proj, t)
self.encoder.diagonalize() # diagonalizes self_attn layers, this is
# purely internal to the self_attn layers.
apply_transformation_in(self.simple_am_proj, cur_transform)
apply_transformation_in(self.joiner.encoder_proj, cur_transform)
@ -255,11 +258,11 @@ def _test_model():
(simple_loss1, pruned_loss1) = model(feats, x_lens, y)
model.diagonalize()
(simple_loss2, pruned_loss2) = model(feats, x_lens, y)
model.diagonalize()
print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
model.diagonalize()