Finish feat-diagonalizing code

commit 07d3369234
parent 2f2934a115
repo   https://github.com/k2-fsa/icefall.git
@@ -32,7 +32,8 @@ from scaling import (
 )
 from torch import Tensor, nn
 from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
-    apply_transformation_in, apply_transformation_out, apply_transformation_inout
+    apply_transformation_in, apply_transformation_out, apply_transformation_inout, \
+    OrthogonalTransformation
 
 
 from icefall.utils import make_pad_mask
@@ -179,6 +180,8 @@ class ConformerEncoderLayer(nn.Module):
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
 
+        self.orth = OrthogonalTransformation(d_model)  # not trainable; used in re-diagonalizing features.
+
         self.layer_dropout = layer_dropout
 
         self.d_model = d_model
@@ -240,6 +243,7 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
+        src = self.orth(src)
         src_orig = src
 
         warmup_scale = min(0.1 + warmup, 1.0)
@@ -288,14 +292,29 @@ class ConformerEncoderLayer(nn.Module):
                 self.self_attn.get_diag_covar_inout() +
                 self.conv_module.get_diag_covar_inout())
 
 
     @torch.no_grad()
-    def apply_transformation_inout(self, t: Tensor) -> None:
+    def apply_transformation_in(self, t: Tensor) -> None:
+        """
+        Rotate only the input feature space with an orthogonal matrix.
+        t is indexed (new_channel_dim, old_channel_dim)
+        """
+        self.orth.apply_transformation_in(t)
+
+    @torch.no_grad()
+    def apply_transformation_out(self, t: Tensor) -> None:
+        self.orth.apply_transformation_out(t)
         apply_transformation_inout(self.feed_forward, t)
         apply_transformation_inout(self.feed_forward_macaron, t)
         self.self_attn.apply_transformation_inout(t)
         self.conv_module.apply_transformation_inout(t)
 
 
+    @torch.no_grad()
+    def get_transformation_out(self) -> Tensor:
+        return self.orth.get_transformation_out()
+
+
 class ConformerEncoder(nn.Module):
     r"""ConformerEncoder is a stack of N encoder layers
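
The helpers apply_transformation_in, apply_transformation_out and apply_transformation_inout imported from diagonalize are called above on sub-modules such as feed_forward, but their bodies are not part of this diff. A minimal sketch of what they plausibly do for a plain nn.Linear (an assumption; the real helpers presumably handle the scaled/composite modules used in icefall), following the convention that t is indexed (new_channel_dim, old_channel_dim), is:

# Sketch only, not the diagonalize.py implementation.
import torch
from torch import Tensor, nn


@torch.no_grad()
def apply_transformation_in_sketch(m: nn.Linear, t: Tensor) -> None:
    # Inputs are now t @ x, so absorb t.t() into the input side of the weight.
    m.weight[:] = torch.matmul(m.weight, t.t())


@torch.no_grad()
def apply_transformation_out_sketch(m: nn.Linear, t: Tensor) -> None:
    # Outputs should become t @ y, so pre-multiply weight (and bias) by t.
    m.weight[:] = torch.matmul(t, m.weight)
    if m.bias is not None:
        m.bias[:] = torch.matmul(t, m.bias)


@torch.no_grad()
def apply_transformation_inout_sketch(m: nn.Linear, t: Tensor) -> None:
    # Rotate both the input and the output feature space, as done above for
    # feed_forward and feed_forward_macaron, whose input and output dims are d_model.
    apply_transformation_in_sketch(m, t)
    apply_transformation_out_sketch(m, t)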
@@ -199,3 +199,64 @@ def get_transformation(cov: Tensor) -> Tensor:
     logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
                  f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
     return U.t()  # U.t() is indexed (new_dim, old_dim)
+
+
+class OrthogonalTransformation(nn.Module):
+
+    def __init__(self, num_channels: int):
+        super(OrthogonalTransformation, self).__init__()
+        # `weight` is indexed (channel_out, channel_in)
+        self.register_buffer('weight', torch.eye(num_channels))  # not a parameter
+
+        self.register_buffer('feats_cov', torch.eye(num_channels))  # not a parameter
+
+        self.step = 0    # just to co-ordinate updating feats_cov every 10 batches; not saved to disk.
+        self.beta = 0.9  # affects how long we remember the stats; not super critical.
+
+    def forward(self, x: Tensor):
+        """
+        Args:
+          x: Tensor of shape (*, num_channels)
+        Returns:
+          Tensor of shape (*, num_channels), x multiplied by orthogonal matrix.
+        """
+        x = torch.matmul(x, self.weight.t())
+        if self.step % 10 == 0 and self.training:
+            # Update covariance stats every 10 batches (in training mode only);
+            # store covariance after input transform.
+            f = x.reshape(-1, x.shape[-1])
+            cov = torch.matmul(f.t(), f)  # channel_dim by channel_dim
+            self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
+        self.step += 1
+        return x
+
+    @torch.no_grad()
+    def apply_transformation_in(self, t: Tensor) -> None:
+        """
+        Rotate only the input feature space with an orthogonal matrix.
+        t is indexed (new_channel_dim, old_channel_dim)
+        """
+        # note, self.weight is indexed (channel_out, channel_in), interpreted
+        # initially as (channel_out, old_channel_in), which we multiply
+        # by t.t(), which is (old_channel_in, new_channel_in)
+        self.weight[:] = torch.matmul(self.weight, t.t())
+
+    @torch.no_grad()
+    def apply_transformation_out(self, t: Tensor) -> None:
+        """
+        Rotate only the output feature space with an orthogonal matrix.
+        t is indexed (new_channel_dim, old_channel_dim)
+
+        We rotate the stored covariance stats rather than re-estimating them;
+        any resulting mismatch will decay.
+        """
+        # note, self.weight is indexed (channel_out, channel_in), interpreted
+        # initially as (old_channel_out, channel_in), which we pre-multiply
+        # by t, which is (new_channel_out, old_channel_out)
+        self.weight[:] = torch.matmul(t, self.weight)
+        self.feats_cov[:] = torch.matmul(t, torch.matmul(self.feats_cov, t.t()))
+
+    @torch.no_grad()
+    def get_transformation_out(self) -> Tensor:
+        # see also get_transformation() above for notes on this.
+        cov = self.feats_cov
+        return get_transformation(cov)
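
Only the tail of get_transformation() appears in this hunk (the logging line and return U.t()). A minimal sketch of such a function, assuming it diagonalizes the symmetric covariance via an eigendecomposition and returns the eigenvector basis indexed (new_dim, old_dim), consistent with l[-1] being the largest eigenvalue in the logging line, is:

# Hypothetical sketch only; the real get_transformation() in diagonalize.py
# may differ (e.g. in sorting, logging and numerical safeguards).
import torch
from torch import Tensor


def get_transformation_sketch(cov: Tensor) -> Tensor:
    """Return an orthogonal matrix, indexed (new_dim, old_dim), that
    diagonalizes the symmetric covariance `cov` of shape (old_dim, old_dim)."""
    cov = 0.5 * (cov + cov.t())    # ensure exact symmetry
    l, U = torch.linalg.eigh(cov)  # l ascending; cov == U @ diag(l) @ U.t()
    return U.t()                   # rows are eigenvectors: (new_dim, old_dim)

With this convention, torch.matmul(t, torch.matmul(cov, t.t())) is approximately diagonal, which is exactly how feats_cov is rotated in apply_transformation_out above.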
@@ -197,15 +197,18 @@ class Transducer(nn.Module):
 
 
     def diagonalize(self) -> None:
-        self.encoder.diagonalize()  # diagonalizes self_attn layers.
+        cur_transform = None
+        for l in self.encoder.encoder.layers:
+            if cur_transform is not None:
+                l.apply_transformation_in(cur_transform)
+            cur_transform = l.get_transformation_out()
+            l.apply_transformation_out(cur_transform)
 
-        diag_covar = (get_diag_covar_in(self.simple_am_proj) +
-                      get_diag_covar_in(self.joiner.encoder_proj) +
-                      self.encoder.get_diag_covar_out())
-        t = get_transformation(diag_covar)
-        self.encoder.apply_transformation_out(t)
-        apply_transformation_in(self.simple_am_proj, t)
-        apply_transformation_in(self.joiner.encoder_proj, t)
+        self.encoder.diagonalize()  # diagonalizes self_attn layers; this is
+                                    # purely internal to the self_attn layers.
+
+        apply_transformation_in(self.simple_am_proj, cur_transform)
+        apply_transformation_in(self.joiner.encoder_proj, cur_transform)
 
 
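
The rewritten Transducer.diagonalize() chains the rotations layer by layer: each encoder layer first absorbs the previous layer's output rotation on its input side, then computes and applies its own output rotation, and the final rotation is absorbed into the input side of simple_am_proj and joiner.encoder_proj. The toy example below (not icefall code; plain nn.Linear modules and a random orthogonal matrix are assumptions) shows why absorbing an orthogonal t on the output side of one module and t.t() on the input side of the next leaves the composed function unchanged:

# Toy illustration of why chaining the rotations preserves the network function.
import torch

torch.manual_seed(0)
d = 8
layer1 = torch.nn.Linear(d, d)
layer2 = torch.nn.Linear(d, d)
x = torch.randn(3, d)

y_ref = layer2(layer1(x))

# random orthogonal matrix, indexed (new_dim, old_dim)
t, _ = torch.linalg.qr(torch.randn(d, d))

with torch.no_grad():
    # "apply_transformation_out" on layer1: new output = t @ old output
    layer1.weight[:] = t @ layer1.weight
    layer1.bias[:] = t @ layer1.bias
    # "apply_transformation_in" on layer2: absorb t.t() into its input side
    layer2.weight[:] = layer2.weight @ t.t()

y_new = layer2(layer1(x))
print(torch.allclose(y_ref, y_new, atol=1e-5))  # True, up to float roundoff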
@@ -255,11 +258,11 @@ def _test_model():
     (simple_loss1, pruned_loss1) = model(feats, x_lens, y)
+    model.diagonalize()
     (simple_loss2, pruned_loss2) = model(feats, x_lens, y)
-    model.diagonalize()
 
     print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
     print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
 
     model.diagonalize()
 
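
Since every transformation applied by diagonalize() is orthogonal and absorbed into adjacent weights, the two losses printed by _test_model() should agree up to floating-point error. If one wanted the test to check this automatically, a possible addition (not part of this commit; tolerances chosen loosely) would be:

# Possible addition at the end of _test_model(), after the print statements:
torch.testing.assert_close(simple_loss1.mean(), simple_loss2.mean(), rtol=1e-3, atol=1e-4)
torch.testing.assert_close(pruned_loss1.mean(), pruned_loss2.mean(), rtol=1e-3, atol=1e-4)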