diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py index 07cc55895..7af83d4fe 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py @@ -32,7 +32,8 @@ from scaling import ( ) from torch import Tensor, nn from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \ - apply_transformation_in, apply_transformation_out, apply_transformation_inout + apply_transformation_in, apply_transformation_out, apply_transformation_inout, \ + OrthogonalTransformation from icefall.utils import make_pad_mask @@ -179,6 +180,8 @@ class ConformerEncoderLayer(nn.Module): ) -> None: super(ConformerEncoderLayer, self).__init__() + self.orth = OrthogonalTransformation(d_model) # not trainable; used in re-diagonalizing features. + self.layer_dropout = layer_dropout self.d_model = d_model @@ -240,6 +243,7 @@ class ConformerEncoderLayer(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number """ + src = self.orth(src) src_orig = src warmup_scale = min(0.1 + warmup, 1.0) @@ -288,14 +292,29 @@ class ConformerEncoderLayer(nn.Module): self.self_attn.get_diag_covar_inout() + self.conv_module.get_diag_covar_inout()) + @torch.no_grad() - def apply_transformation_inout(self, t: Tensor) -> None: + def apply_transformation_in(self, t: Tensor) -> None: + """ + Rotate only the input feature space with an orthogonal matrix. + t is indexed (new_channel_dim, old_channel_dim) + """ + self.orth.apply_transformation_in(t) + + @torch.no_grad() + def apply_transformation_out(self, t: Tensor) -> None: + self.orth.apply_transformation_out(t) apply_transformation_inout(self.feed_forward, t) apply_transformation_inout(self.feed_forward_macaron, t) self.self_attn.apply_transformation_inout(t) self.conv_module.apply_transformation_inout(t) + @torch.no_grad() + def get_transformation_out(self) -> Tensor: + return self.orth.get_transformation_out() + + class ConformerEncoder(nn.Module): r"""ConformerEncoder is a stack of N encoder layers diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py index 2e4e2ee9e..1da066e54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py @@ -199,3 +199,64 @@ def get_transformation(cov: Tensor) -> Tensor: logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} " f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}") return U.t() # U.t() is indexed (new_dim, old_dim) + +class OrthogonalTransformation(nn.Module): + + def __init__(self, num_channels: int): + super(OrthogonalTransformation, self).__init__() + # `weight` is indexed (channel_out, channel_in) + self.register_buffer('weight', torch.eye(num_channels)) # not a parameter + + self.register_buffer('feats_cov', torch.eye(num_channels)) # not a parameter + + self.step = 0 # just to co-ordinate updating feats_cov every 10 batches; not saved to disk. + self.beta = 0.9 # affects how long we remember the stats. not super critical. + + def forward(self, x: Tensor): + """ + Args: + x: Tensor of shape (*, num_channel) + Returns: + Tensor of shape (*, num_channels), x multiplied by orthogonal matrix. + """ + x = torch.matmul(x, self.weight.t()) + if self.step % 10 == 0 and self.train(): + # store covariance after input transform. + # Update covariance stats every 10 batches (in training mode) + f = x.reshape(-1, x.shape[-1]) + cov = torch.matmul(f.t(), f) # channel_dim by channel_dim + self.feats_cov.mul_(self.beta).add_(cov, alpha=(1-self.beta)) + self.step += 1 + return x + + @torch.no_grad() + def apply_transformation_in(self, t: Tensor) -> None: + """ + Rotate only the input feature space with an orthogonal matrix. + t is indexed (new_channel_dim, old_channel_dim) + """ + # note, self.weight is indexed (channel_out, channel_in), interpreted + # initially as (channel_out, old_channel_in), which we multiply + # by t.t() which is (old_channel_in, new_channel_in) + self.weight[:] = torch.matmul(self.weight, t.t()) + + @torch.no_grad() + def apply_transformation_out(self, t: Tensor) -> None: + """ + Rotate only the output feature space with an orthogonal matrix. + t is indexed (new_channel_dim, old_channel_dim) + + We don't bother updating the covariance stats; they will decay. + """ + # note, self.weight is indexed (channel_out, channel_in), interpreted + # initially as (old_channel_out, old_channe), which we pre-multiply + # by t which is (new_channel_out, old_channel_out) + self.weight[:] = torch.matmul(t, self.weight) + self.feats_cov[:] = torch.matmul(t, torch.matmul(self.feats_cov, t.t())) + + + @torch.no_grad() + def get_transformation_out(self) -> Tensor: + # see also get_transformation() above for notes on this. + cov = self.feats_cov + return get_transformation(cov) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py index 1bb74adf7..243288b99 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py @@ -197,15 +197,18 @@ class Transducer(nn.Module): def diagonalize(self) -> None: - self.encoder.diagonalize() # diagonalizes self_attn layers. + cur_transform = None + for l in self.encoder.encoder.layers: + if cur_transform is not None: + l.apply_transformation_in(cur_transform) + cur_transform = l.get_transformation_out() + l.apply_transformation_out(cur_transform) - diag_covar = (get_diag_covar_in(self.simple_am_proj) + - get_diag_covar_in(self.joiner.encoder_proj) + - self.encoder.get_diag_covar_out()) - t = get_transformation(diag_covar) - self.encoder.apply_transformation_out(t) - apply_transformation_in(self.simple_am_proj, t) - apply_transformation_in(self.joiner.encoder_proj, t) + self.encoder.diagonalize() # diagonalizes self_attn layers, this is + # purely internal to the self_attn layers. + + apply_transformation_in(self.simple_am_proj, cur_transform) + apply_transformation_in(self.joiner.encoder_proj, cur_transform) @@ -255,11 +258,11 @@ def _test_model(): (simple_loss1, pruned_loss1) = model(feats, x_lens, y) model.diagonalize() (simple_loss2, pruned_loss2) = model(feats, x_lens, y) - model.diagonalize() - print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}") print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}") + model.diagonalize() +