mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Rename orthogonalize to diagonalize
This commit is contained in:
parent
9859e33c06
commit
8aeaf1421a
@ -19,7 +19,7 @@ import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
@ -127,11 +127,6 @@ class Conformer(EncoderInterface):
|
||||
|
||||
return x, lengths
|
||||
|
||||
def orthogonalize(self) -> None:
|
||||
# currently only orthogonalize the self-attention modules, but could in principle
|
||||
# do more layers.
|
||||
for m in self.encoder.layers:
|
||||
m.self_attn.orthogonalize()
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
@ -326,7 +321,6 @@ class ConformerEncoder(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
|
||||
@ -835,79 +829,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
else:
|
||||
return attn_output, None
|
||||
|
||||
@torch.no_grad()
|
||||
def orthogonalize(self) -> None:
|
||||
"""
|
||||
Rotate some parameters to try to segregate large vs. small rows/columns of
|
||||
the parameter matrices. The intention is to improve the convergence of
|
||||
Adam-type update formulas (with SGD, the update would be invariant to
|
||||
such rotations).
|
||||
"""
|
||||
num_heads = self.num_heads
|
||||
attn_dim = self.in_proj.weight.shape[1]
|
||||
query_proj = self.in_proj.weight[0:attn_dim,:].chunk(num_heads, dim=0)
|
||||
key_proj = self.in_proj.weight[attn_dim:2*attn_dim,:].chunk(num_heads, dim=0)
|
||||
linear_pos_proj = self.linear_pos.weight.chunk(num_heads, dim=0)
|
||||
value_proj = self.in_proj.weight[2*attn_dim:,:].chunk(num_heads, dim=0)
|
||||
out_proj = self.out_proj.weight.t().chunk(num_heads, dim=0)
|
||||
|
||||
def get_normalized_covar(x: Tensor) -> Tensor:
|
||||
"""
|
||||
Returns a covariance matrix normalized to have trace==dim.
|
||||
Args:
|
||||
x: a matrix of shape (i, j)
|
||||
Returns: a covariance matrix of shape (i, i), equal to matmul(x, x.t())
|
||||
"""
|
||||
covar = torch.matmul(x, x.t())
|
||||
return covar * (x.shape[0] / covar.trace())
|
||||
|
||||
|
||||
def get_proj(*args) -> Tensor:
|
||||
"""
|
||||
Returns a covariance-diagonalizing projection that diagonalizes
|
||||
the summed covariance from these two projections. If mat1,mat2
|
||||
are of shape (dim0, dim1), it's the (dim0, dim0) covariance,
|
||||
that we diagonalize.
|
||||
|
||||
Args: mat1, mat2, etc., which should all be matrices of the same
|
||||
shape (dim0, dim1)
|
||||
|
||||
Returns: a projection indexed (new_dim0, old_dim0), i.e. of
|
||||
shape dim0 by dim0 but 1st index is the newly created indexes.
|
||||
"""
|
||||
cov = get_normalized_covar(args[0])
|
||||
for a in args[1:]:
|
||||
cov += get_normalized_covar(a)
|
||||
old_diag_stddev = cov.diag().var().sqrt().item()
|
||||
l, U = cov.symeig(eigenvectors=True)
|
||||
new_diag_stddev = l.var().sqrt().item()
|
||||
logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
|
||||
f"to {new_diag_stddev:.3e}")
|
||||
return U.t() # U.t() is indexed (new_dim, old_dim)
|
||||
|
||||
# first do the query/key space
|
||||
logging.info("Diagonalizing query/key space")
|
||||
for i in range(num_heads):
|
||||
q,k,l,pos_u,pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i]
|
||||
qk_proj = get_proj(q, k, l)
|
||||
q[:] = torch.matmul(qk_proj, q)
|
||||
k[:] = torch.matmul(qk_proj, k)
|
||||
l[:] = torch.matmul(qk_proj, l)
|
||||
pos_u[:] = torch.mv(qk_proj, pos_u)
|
||||
pos_v[:] = torch.mv(qk_proj, pos_v)
|
||||
|
||||
# Now do the value space
|
||||
logging.info("Diagonalizing value space")
|
||||
for i in range(num_heads):
|
||||
v, o = value_proj[i], out_proj[i]
|
||||
v_proj = get_proj(v, o)
|
||||
v[:] = torch.matmul(v_proj, v)
|
||||
o[:] = torch.matmul(v_proj, o)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
@ -1105,22 +1026,13 @@ class Conv2dSubsampling(nn.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
feature_dim = 50
|
||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||
batch_size = 5
|
||||
seq_len = 20
|
||||
# Make sure the forward pass runs, and that orthogonalize() does not
|
||||
# change its output.
|
||||
feats = torch.randn(batch_size, seq_len, feature_dim)
|
||||
x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
|
||||
|
||||
c.eval() # use test mode and warmup=1.0 so it is deterministic.
|
||||
y1 = c(feats, x_lens, warmup=1.0)[0]
|
||||
c.orthogonalize()
|
||||
y2 = c(feats, x_lens, warmup=1.0)[0]
|
||||
|
||||
diff_norm = (y1-y2).norm()
|
||||
y_norm = y1.norm()
|
||||
print(f"diff_norm={diff_norm}, y_norm={y_norm}")
|
||||
assert diff_norm < 0.001 * y_norm
|
||||
# Just make sure the forward pass runs.
|
||||
f = c(
|
||||
torch.randn(batch_size, seq_len, feature_dim),
|
||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||
warmup=0.5,
|
||||
)
|
||||
|
@ -1 +0,0 @@
|
||||
../pruned_transducer_stateless2/conformer.py
|
1126
egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py
Normal file
1126
egs/librispeech/ASR/pruned_transducer_stateless4b/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -703,7 +703,7 @@ def train_one_epoch(
|
||||
|
||||
if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0:
|
||||
mmodel = model.module if hasattr(model, 'module') else model
|
||||
mmodel.encoder.orthogonalize()
|
||||
mmodel.encoder.diagonalize()
|
||||
#optimizer.reset()
|
||||
|
||||
params.batch_idx_train += 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user