mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Finish feat-diagonalizing code
This commit is contained in:
parent
2f2934a115
commit
07d3369234
@ -32,7 +32,8 @@ from scaling import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
|
from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
|
||||||
apply_transformation_in, apply_transformation_out, apply_transformation_inout
|
apply_transformation_in, apply_transformation_out, apply_transformation_inout, \
|
||||||
|
OrthogonalTransformation
|
||||||
|
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
@ -179,6 +180,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super(ConformerEncoderLayer, self).__init__()
|
super(ConformerEncoderLayer, self).__init__()
|
||||||
|
|
||||||
|
self.orth = OrthogonalTransformation(d_model) # not trainable; used in re-diagonalizing features.
|
||||||
|
|
||||||
self.layer_dropout = layer_dropout
|
self.layer_dropout = layer_dropout
|
||||||
|
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
@ -240,6 +243,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
"""
|
"""
|
||||||
|
src = self.orth(src)
|
||||||
src_orig = src
|
src_orig = src
|
||||||
|
|
||||||
warmup_scale = min(0.1 + warmup, 1.0)
|
warmup_scale = min(0.1 + warmup, 1.0)
|
||||||
@ -288,14 +292,29 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.self_attn.get_diag_covar_inout() +
|
self.self_attn.get_diag_covar_inout() +
|
||||||
self.conv_module.get_diag_covar_inout())
|
self.conv_module.get_diag_covar_inout())
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def apply_transformation_inout(self, t: Tensor) -> None:
|
def apply_transformation_in(self, t: Tensor) -> None:
|
||||||
|
"""
|
||||||
|
Rotate only the input feature space with an orthogonal matrix.
|
||||||
|
t is indexed (new_channel_dim, old_channel_dim)
|
||||||
|
"""
|
||||||
|
self.orth.apply_transformation_in(t)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def apply_transformation_out(self, t: Tensor) -> None:
|
||||||
|
self.orth.apply_transformation_out(t)
|
||||||
apply_transformation_inout(self.feed_forward, t)
|
apply_transformation_inout(self.feed_forward, t)
|
||||||
apply_transformation_inout(self.feed_forward_macaron, t)
|
apply_transformation_inout(self.feed_forward_macaron, t)
|
||||||
self.self_attn.apply_transformation_inout(t)
|
self.self_attn.apply_transformation_inout(t)
|
||||||
self.conv_module.apply_transformation_inout(t)
|
self.conv_module.apply_transformation_inout(t)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_transformation_out(self) -> Tensor:
|
||||||
|
return self.orth.get_transformation_out()
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoder(nn.Module):
|
class ConformerEncoder(nn.Module):
|
||||||
r"""ConformerEncoder is a stack of N encoder layers
|
r"""ConformerEncoder is a stack of N encoder layers
|
||||||
|
|
||||||
|
@ -199,3 +199,64 @@ def get_transformation(cov: Tensor) -> Tensor:
|
|||||||
logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
|
logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
|
||||||
f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
|
f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
|
||||||
return U.t() # U.t() is indexed (new_dim, old_dim)
|
return U.t() # U.t() is indexed (new_dim, old_dim)
|
||||||
|
|
||||||
|
class OrthogonalTransformation(nn.Module):
    """
    Non-trainable linear map over the channel (last) dimension.

    The map starts out as the identity and is only ever changed through
    `apply_transformation_in` / `apply_transformation_out`, which compose it
    with a caller-supplied matrix (expected to be orthogonal).  While in
    training mode the module also keeps a decaying estimate of the
    uncentered covariance of its output, from which
    `get_transformation_out` derives a new output-side transformation.
    """

    def __init__(self, num_channels: int):
        """
        Args:
          num_channels: size of the channel dimension; both buffers are
             (num_channels, num_channels) and start as the identity.
        """
        super(OrthogonalTransformation, self).__init__()
        # `weight` is indexed (channel_out, channel_in)
        self.register_buffer('weight', torch.eye(num_channels))  # not a parameter

        # decaying sum of outer products of the output features.
        self.register_buffer('feats_cov', torch.eye(num_channels))  # not a parameter

        self.step = 0  # just to co-ordinate updating feats_cov every 10 batches; not saved to disk.
        self.beta = 0.9  # affects how long we remember the stats.  not super critical.

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
          x: Tensor of shape (*, num_channels)
        Returns:
          Tensor of shape (*, num_channels), x multiplied by orthogonal matrix.
        """
        x = torch.matmul(x, self.weight.t())
        # Use the `training` flag here: `self.train()` returns the module
        # itself (always truthy) and would also flip it into training mode.
        if self.step % 10 == 0 and self.training:
            # Update covariance stats every 10 batches (in training mode),
            # measured after the transform.  Done under no_grad so that the
            # persistent buffer never becomes part of the autograd graph.
            with torch.no_grad():
                f = x.reshape(-1, x.shape[-1])
                cov = torch.matmul(f.t(), f)  # channel_dim by channel_dim
                self.feats_cov.mul_(self.beta).add_(cov, alpha=(1 - self.beta))
        self.step += 1
        return x

    @torch.no_grad()
    def apply_transformation_in(self, t: Tensor) -> None:
        """
        Rotate only the input feature space with an orthogonal matrix.
        t is indexed (new_channel_dim, old_channel_dim)
        """
        # note, self.weight is indexed (channel_out, channel_in), interpreted
        # initially as (channel_out, old_channel_in), which we multiply
        # by t.t() which is (old_channel_in, new_channel_in)
        self.weight[:] = torch.matmul(self.weight, t.t())

    @torch.no_grad()
    def apply_transformation_out(self, t: Tensor) -> None:
        """
        Rotate only the output feature space with an orthogonal matrix.
        t is indexed (new_channel_dim, old_channel_dim)

        The accumulated covariance stats are rotated the same way, so they
        stay expressed in the new output basis.
        """
        # note, self.weight is indexed (channel_out, channel_in), interpreted
        # initially as (old_channel_out, channel_in), which we pre-multiply
        # by t which is (new_channel_out, old_channel_out)
        self.weight[:] = torch.matmul(t, self.weight)
        # cov_new = t cov_old t^T
        self.feats_cov[:] = torch.matmul(t, torch.matmul(self.feats_cov, t.t()))

    @torch.no_grad()
    def get_transformation_out(self) -> Tensor:
        """
        Return the transformation that the module-level function
        get_transformation() computes from the accumulated output
        covariance; see that function for notes on the result.
        """
        cov = self.feats_cov
        return get_transformation(cov)
|
||||||
|
@ -197,15 +197,18 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def diagonalize(self) -> None:
|
def diagonalize(self) -> None:
|
||||||
self.encoder.diagonalize() # diagonalizes self_attn layers.
|
cur_transform = None
|
||||||
|
for l in self.encoder.encoder.layers:
|
||||||
|
if cur_transform is not None:
|
||||||
|
l.apply_transformation_in(cur_transform)
|
||||||
|
cur_transform = l.get_transformation_out()
|
||||||
|
l.apply_transformation_out(cur_transform)
|
||||||
|
|
||||||
diag_covar = (get_diag_covar_in(self.simple_am_proj) +
|
self.encoder.diagonalize() # diagonalizes self_attn layers, this is
|
||||||
get_diag_covar_in(self.joiner.encoder_proj) +
|
# purely internal to the self_attn layers.
|
||||||
self.encoder.get_diag_covar_out())
|
|
||||||
t = get_transformation(diag_covar)
|
apply_transformation_in(self.simple_am_proj, cur_transform)
|
||||||
self.encoder.apply_transformation_out(t)
|
apply_transformation_in(self.joiner.encoder_proj, cur_transform)
|
||||||
apply_transformation_in(self.simple_am_proj, t)
|
|
||||||
apply_transformation_in(self.joiner.encoder_proj, t)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -255,11 +258,11 @@ def _test_model():
|
|||||||
(simple_loss1, pruned_loss1) = model(feats, x_lens, y)
|
(simple_loss1, pruned_loss1) = model(feats, x_lens, y)
|
||||||
model.diagonalize()
|
model.diagonalize()
|
||||||
(simple_loss2, pruned_loss2) = model(feats, x_lens, y)
|
(simple_loss2, pruned_loss2) = model(feats, x_lens, y)
|
||||||
model.diagonalize()
|
|
||||||
|
|
||||||
print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
|
print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
|
||||||
print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
|
print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
|
||||||
|
|
||||||
|
model.diagonalize()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user