diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py
index 5ff132f14..07cc55895 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py
@@ -31,6 +31,9 @@ from scaling import (
     ScaledLinear,
 )
 from torch import Tensor, nn
+from diagonalize import get_diag_covar_in, get_diag_covar_out, get_diag_covar_inout, \
+    apply_transformation_in, apply_transformation_out, apply_transformation_inout
+
 from icefall.utils import make_pad_mask
 
 
@@ -128,11 +131,24 @@ class Conformer(EncoderInterface):
         return x, lengths
 
     def diagonalize(self) -> None:
-        # currently only diagonalize the self-attention modules, but could in principle
-        # do more layers.
+        # This only diagonalizes the self-attention modules; to diagonalize the
+        # embedding space, call diagonalize() from class Transducer in model.py.
         for m in self.encoder.layers:
             m.self_attn.diagonalize()
 
+    @torch.no_grad()
+    def get_diag_covar_out(self) -> Tensor:
+        return (self.encoder_embed.get_diag_covar_out() +
+                sum([l.get_diag_covar_inout() for l in self.encoder.layers]))
+
+    @torch.no_grad()
+    def apply_transformation_out(self, t: Tensor) -> None:
+        self.encoder_embed.apply_transformation_out(t)
+        for l in self.encoder.layers:
+            l.apply_transformation_inout(t)
+
+
+
 class ConformerEncoderLayer(nn.Module):
     """
     ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@@ -265,6 +281,20 @@ class ConformerEncoderLayer(nn.Module):
 
         return src
 
+    @torch.no_grad()
+    def get_diag_covar_inout(self) -> Tensor:
+        return (get_diag_covar_inout(self.feed_forward) +
+                get_diag_covar_inout(self.feed_forward_macaron) +
+                self.self_attn.get_diag_covar_inout() +
+                self.conv_module.get_diag_covar_inout())
+
+    @torch.no_grad()
+    def apply_transformation_inout(self, t: Tensor) -> None:
+        apply_transformation_inout(self.feed_forward, t)
+        apply_transformation_inout(self.feed_forward_macaron, t)
+        self.self_attn.apply_transformation_inout(t)
+        self.conv_module.apply_transformation_inout(t)
+
 
 class ConformerEncoder(nn.Module):
     r"""ConformerEncoder is a stack of N encoder layers
@@ -862,17 +892,17 @@ class RelPositionMultiheadAttention(nn.Module):
 
             return covar * (x.shape[0] / covar.trace())
 
-        def get_proj(*args) -> Tensor:
+        def get_transformation(*args) -> Tensor:
             """
-            Returns a covariance-diagonalizing projection that diagonalizes
-            the summed covariance from these two projections.  If mat1,mat2
+            Returns a covariance-diagonalizing transformation that diagonalizes
+            the summed covariance from these transformations.  If mat1,mat2
             are of shape (dim0, dim1), it's the (dim0, dim0) covariance,
             that we diagonalize.
 
             Args: mat1, mat2, etc., which should all be matrices of the same
             shape (dim0, dim1)
 
-            Returns: a projection indexed (new_dim0, old_dim0), i.e. of
+            Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
             shape dim0 by dim0 but 1st index is the newly created indexes.
""" cov = get_normalized_covar(args[0]) @@ -889,23 +919,30 @@ class RelPositionMultiheadAttention(nn.Module): logging.info("Diagonalizing query/key space") for i in range(num_heads): q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i] - qk_proj = get_proj(q, k, l) - q[:] = torch.matmul(qk_proj, q) - k[:] = torch.matmul(qk_proj, k) - l[:] = torch.matmul(qk_proj, l) - pos_u[:] = torch.mv(qk_proj, pos_u) - pos_v[:] = torch.mv(qk_proj, pos_v) + qk_trans = get_transformation(q, k, l) + q[:] = torch.matmul(qk_trans, q) + k[:] = torch.matmul(qk_trans, k) + l[:] = torch.matmul(qk_trans, l) + pos_u[:] = torch.mv(qk_trans, pos_u) + pos_v[:] = torch.mv(qk_trans, pos_v) # Now do the value space logging.info("Diagonalizing value space") for i in range(num_heads): v, o = value_proj[i], out_proj[i] - v_proj = get_proj(v, o) - v[:] = torch.matmul(v_proj, v) - o[:] = torch.matmul(v_proj, o) - + v_trans = get_transformation(v, o) + v[:] = torch.matmul(v_trans, v) + o[:] = torch.matmul(v_trans, o) + @torch.no_grad() + def get_diag_covar_inout(self) -> Tensor: + return (get_diag_covar_in(self.in_proj) + + get_diag_covar_out(self.out_proj)) + @torch.no_grad() + def apply_transformation_inout(self, t: Tensor) -> None: + apply_transformation_in(self.in_proj, t) + apply_transformation_out(self.out_proj, t) @@ -1009,6 +1046,17 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1) + @torch.no_grad() + def get_diag_covar_inout(self) -> Tensor: + return (get_diag_covar_in(self.pointwise_conv1) + + get_diag_covar_out(self.pointwise_conv2)) + + @torch.no_grad() + def apply_transformation_inout(self, t: Tensor) -> None: + apply_transformation_in(self.pointwise_conv1, t) + apply_transformation_out(self.pointwise_conv2, t) + + class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). @@ -1103,6 +1151,15 @@ class Conv2dSubsampling(nn.Module): x = self.out_balancer(x) return x + @torch.no_grad() + def get_diag_covar_out(self) -> Tensor: + return get_diag_covar_out(self.out) + + @torch.no_grad() + def apply_transformation_out(self, t: Tensor) -> None: + apply_transformation_out(self.out, t) + + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py index 281785c4e..2e4e2ee9e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py @@ -59,7 +59,7 @@ def get_diag_covar_in(m: nn.Module) -> Tensor: w = w.reshape(in_channels, -1) return _get_normalized_covar(w) # (in_channels, in_channels) elif isinstance(m, nn.Sequential): - return get_diag_covar_in(m[0]) + return get_diag_covar_in(m[0], t) else: # some modules have this function; if not, at this point, it is an error. return m.get_diag_covar_in() @@ -135,7 +135,7 @@ def apply_transformation_in(m: nn.Module, t: Tensor) -> None: w = w.reshape(m.weight.shape) # (out_channels, in_channels, [1 or 2 kernel dims]) m.weight[:] = w elif isinstance(m, nn.Sequential): - apply_transformation_in(m[0]) + apply_transformation_in(m[0], t) else: # some modules have this function; if not, at this point, it is an error. 
         m.apply_transformation_in(t)
@@ -167,7 +167,7 @@ def apply_transformation_out(m: nn.Module, t: Tensor) -> None:
         if m.bias is not None:
             m.bias[:] = torch.matmul(t, m.bias)
     elif isinstance(m, nn.Sequential):
-        apply_transformation_out(m[-1])
+        apply_transformation_out(m[-1], t)
     else:
         # some modules have this function; if not, at this point, it is an error.
         m.apply_transformation_out(t)
@@ -193,12 +193,9 @@ def get_transformation(cov: Tensor) -> Tensor:
     Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
     shape dim0 by dim0 but 1st index is the newly created indexes.
     """
-    cov = get_normalized_covar(args[0])
-    for a in args[1:]:
-        cov += get_normalized_covar(a)
     old_diag_stddev = cov.diag().var().sqrt().item()
     l, U = cov.symeig(eigenvectors=True)
     new_diag_stddev = l.var().sqrt().item()
     logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
-                 f"to {new_diag_stddev:.3e}")
+                 f"to {new_diag_stddev:.3e}, max diag elem changed from {cov.diag().max().item():.2e} to {l[-1].item():.2e}")
     return U.t() # U.t() is indexed (new_dim, old_dim)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py
index 01b9ecc0e..1bb74adf7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py
@@ -21,7 +21,7 @@ from torch import Tensor
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
-from diagonalize import get_diag_covar_in
+from diagonalize import get_diag_covar_in, apply_transformation_in, get_transformation, apply_transformation_out
 
 from icefall.utils import add_sos
 
@@ -195,7 +195,73 @@ class Transducer(nn.Module):
 
         return (simple_loss, pruned_loss)
 
-    def get_diag_covar_in(self) -> Tensor:
-        return (get_diag_covar_in(self.simple_am_proj) +
-                get_diag_covar_in(joiner.encoder_proj) +
-                self.encoder.get_diag_covar_out())
+
+    def diagonalize(self) -> None:
+        self.encoder.diagonalize()  # diagonalizes the self_attn layers.
+
+        diag_covar = (get_diag_covar_in(self.simple_am_proj) +
+                      get_diag_covar_in(self.joiner.encoder_proj) +
+                      self.encoder.get_diag_covar_out())
+        t = get_transformation(diag_covar)
+        self.encoder.apply_transformation_out(t)
+        apply_transformation_in(self.simple_am_proj, t)
+        apply_transformation_in(self.joiner.encoder_proj, t)
+
+
+
+def _test_model():
+    import logging
+    logging.getLogger().setLevel(logging.INFO)
+    from conformer import Conformer
+    from joiner import Joiner
+    from decoder import Decoder
+    feature_dim = 40
+    encoder_dim = 512
+    decoder_dim = 513
+    joiner_dim = 514
+    vocab_size = 1000
+    encoder = Conformer(num_features=feature_dim,
+                        subsampling_factor=4,
+                        d_model=encoder_dim,
+                        nhead=4,
+                        dim_feedforward=512,
+                        num_encoder_layers=4)
+    decoder = Decoder(
+        vocab_size=vocab_size,
+        decoder_dim=decoder_dim,
+        blank_id=0,
+        context_size=2)
+    joiner = Joiner(
+        encoder_dim=encoder_dim,
+        decoder_dim=decoder_dim,
+        joiner_dim=joiner_dim,
+        vocab_size=vocab_size)
+    model = Transducer(encoder=encoder,
+                       decoder=decoder,
+                       joiner=joiner,
+                       encoder_dim=encoder_dim,
+                       decoder_dim=decoder_dim,
+                       joiner_dim=joiner_dim,
+                       vocab_size=vocab_size)
+
+    batch_size = 5
+    seq_len = 50
+
+    feats = torch.randn(batch_size, seq_len, feature_dim)
+    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
+    y = k2.ragged.create_ragged_tensor(
+        torch.arange(5, dtype=torch.int32).reshape(1, 5).expand(batch_size, 5).contiguous())
+    model.eval()  # eval mode, so the two forward passes are deterministic.
+    (simple_loss1, pruned_loss1) = model(feats, x_lens, y)
+    model.diagonalize()
+    (simple_loss2, pruned_loss2) = model(feats, x_lens, y)
+    model.diagonalize()  # 2nd call: logging should show the space is already diagonal.
+
+    print(f"simple_loss1 = {simple_loss1.mean().item()}, simple_loss2 = {simple_loss2.mean().item()}")
+    print(f"pruned_loss1 = {pruned_loss1.mean().item()}, pruned_loss2 = {pruned_loss2.mean().item()}")
+
+
+
+
+if __name__ == '__main__':
+    _test_model()
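
The methods added above rely on one invariance worth stating once: get_transformation() returns an orthogonal matrix t (its rows are eigenvectors, t = U.t()), so applying t on the output side of one module (apply_transformation_out) and t.t() on the input side of its consumer (apply_transformation_in) rotates the hidden space into the covariance eigenbasis while leaving the composed function unchanged. The sketch below is not part of the patch and does not import diagonalize.py; it mirrors the idea on a hypothetical pair of plain nn.Linear layers, with the covariance taken from the first layer's outgoing weights, analogous to _get_normalized_covar:

import torch
import torch.nn as nn

torch.manual_seed(0)
lin1, lin2 = nn.Linear(8, 16), nn.Linear(16, 4)
x = torch.randn(10, 8)

with torch.no_grad():
    y_ref = lin2(lin1(x))

    # Covariance over the 16-dim hidden space, from lin1's weight rows.
    w = lin1.weight                # (hidden_dim, in_dim)
    cov = torch.matmul(w, w.t())   # (hidden_dim, hidden_dim)
    _, U = torch.linalg.eigh(cov)
    t = U.t()                      # (new_dim, old_dim), orthogonal

    # Like apply_transformation_out on lin1: new output is t @ old output.
    lin1.weight[:] = torch.matmul(t, lin1.weight)
    lin1.bias[:] = torch.mv(t, lin1.bias)
    # Like apply_transformation_in on lin2: compensate with t.t() == t^-1.
    lin2.weight[:] = torch.matmul(lin2.weight, t.t())

    y_new = lin2(lin1(x))

print(torch.allclose(y_ref, y_new, atol=1e-5))  # True: the function is preserved

This is also why _test_model() above expects simple_loss and pruned_loss to be numerically identical before and after model.diagonalize().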