mirror of https://github.com/k2-fsa/icefall.git
Orthogonalize every 2k iters
This commit is contained in:
parent bb32556f9e
commit cee5396058
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple
-
+import logging
 import torch
 from encoder_interface import EncoderInterface
 from scaling import (
@@ -127,6 +127,11 @@ class Conformer(EncoderInterface):
 
         return x, lengths
 
+    def orthogonalize(self) -> None:
+        # currently only orthogonalize the self-attention modules, but could in principle
+        # do more layers.
+        for m in self.encoder.layers:
+            m.self_attn.orthogonalize()
 
 class ConformerEncoderLayer(nn.Module):
     """
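Note: the wrapper above only delegates; each encoder layer's self_attn is expected to expose an orthogonalize() of its own, which this commit adds to RelPositionMultiheadAttention further down. A toy sketch of that contract (illustrative stand-in classes, not from the commit):

import torch
import torch.nn as nn

class _ToySelfAttn(nn.Module):
    @torch.no_grad()
    def orthogonalize(self) -> None:
        pass  # stand-in for RelPositionMultiheadAttention.orthogonalize()

class _ToyLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.self_attn = _ToySelfAttn()

layers = nn.ModuleList(_ToyLayer() for _ in range(2))  # plays the role of encoder.layers
for m in layers:
    m.self_attn.orthogonalize()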
@@ -321,6 +326,7 @@ class ConformerEncoder(nn.Module):
         return output
 
 
+
 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.
 
@@ -829,6 +835,79 @@ class RelPositionMultiheadAttention(nn.Module):
         else:
             return attn_output, None
 
+    @torch.no_grad()
+    def orthogonalize(self) -> None:
+        """
+        Rotate some parameters to try to segregate large vs. small rows/columns of
+        the parameter matrices. The intention is to improve the convergence of
+        Adam-type update formulas (with SGD, the update would be invariant to
+        such rotations).
+        """
+        num_heads = self.num_heads
+        attn_dim = self.in_proj.weight.shape[1]
+        query_proj = self.in_proj.weight[0:attn_dim, :].chunk(num_heads, dim=0)
+        key_proj = self.in_proj.weight[attn_dim:2 * attn_dim, :].chunk(num_heads, dim=0)
+        linear_pos_proj = self.linear_pos.weight.chunk(num_heads, dim=0)
+        value_proj = self.in_proj.weight[2 * attn_dim:, :].chunk(num_heads, dim=0)
+        out_proj = self.out_proj.weight.t().chunk(num_heads, dim=0)
+
+        def get_normalized_covar(x: Tensor) -> Tensor:
+            """
+            Returns a covariance matrix normalized to have trace==dim.
+            Args:
+                x: a matrix of shape (i, j)
+            Returns: a covariance matrix of shape (i, i), proportional to matmul(x, x.t())
+            """
+            covar = torch.matmul(x, x.t())
+            return covar * (x.shape[0] / covar.trace())
+
+
+        def get_proj(*args) -> Tensor:
+            """
+            Returns an orthogonal projection that diagonalizes the summed
+            covariance of the given matrices.  If mat1, mat2, ... are of
+            shape (dim0, dim1), it is their (dim0, dim0) covariance that
+            we diagonalize.
+
+            Args: mat1, mat2, etc., which should all be matrices of the same
+                shape (dim0, dim1)
+
+            Returns: a projection indexed (new_dim0, old_dim0), i.e. of shape
+                dim0 by dim0, where the 1st index is the newly created index.
+            """
+            cov = get_normalized_covar(args[0])
+            for a in args[1:]:
+                cov += get_normalized_covar(a)
+            old_diag_stddev = cov.diag().var().sqrt().item()
+            l, U = cov.symeig(eigenvectors=True)
+            new_diag_stddev = l.var().sqrt().item()
+            logging.info(f"Variance of diag of param-var changed from {old_diag_stddev:.3e} "
+                         f"to {new_diag_stddev:.3e}")
+            return U.t()  # U.t() is indexed (new_dim, old_dim)
+
+        # first do the query/key space
+        logging.info("Diagonalizing query/key space")
+        for i in range(num_heads):
+            q, k, l, pos_u, pos_v = query_proj[i], key_proj[i], linear_pos_proj[i], self.pos_bias_u[i], self.pos_bias_v[i]
+            qk_proj = get_proj(q, k, l)
+            q[:] = torch.matmul(qk_proj, q)
+            k[:] = torch.matmul(qk_proj, k)
+            l[:] = torch.matmul(qk_proj, l)
+            pos_u[:] = torch.mv(qk_proj, pos_u)
+            pos_v[:] = torch.mv(qk_proj, pos_v)
+
+        # Now do the value space
+        logging.info("Diagonalizing value space")
+        for i in range(num_heads):
+            v, o = value_proj[i], out_proj[i]
+            v_proj = get_proj(v, o)
+            v[:] = torch.matmul(v_proj, v)
+            o[:] = torch.matmul(v_proj, o)
+
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
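Why this is safe to do in the middle of training (and why the new self-test at the bottom of the file can assert an unchanged output): the attention computation only sees inner products within each head's space, so multiplying a head's query and key projections (and, in the same way, its value and output projections) by one common orthogonal matrix leaves the module's output unchanged, while Adam's element-wise statistics now live in a rotated basis. A standalone sketch of the query/key part (shapes are illustrative, not taken from the commit):

import torch

torch.manual_seed(0)
head_dim, attn_dim, num_frames = 8, 32, 10

q_proj = torch.randn(head_dim, attn_dim)  # one head's slice of in_proj.weight
k_proj = torch.randn(head_dim, attn_dim)
x = torch.randn(num_frames, attn_dim)     # frames entering the attention module

# Any orthogonal U works; get_proj() above builds one from eigenvectors.
U, _ = torch.linalg.qr(torch.randn(head_dim, head_dim))

scores = (x @ q_proj.t()) @ (x @ k_proj.t()).t()
scores_rotated = (x @ (U @ q_proj).t()) @ (x @ (U @ k_proj).t()).t()
print(torch.allclose(scores, scores_rotated, atol=1e-4))  # True: scores unchanged

The pos_bias_u / pos_bias_v vectors and the rows of linear_pos live in that same per-head query/key space, which is why the code rotates them with the same qk_proj.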
@@ -1026,13 +1105,22 @@ class Conv2dSubsampling(nn.Module):
 
 
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
-    # Just make sure the forward pass runs.
-    f = c(
-        torch.randn(batch_size, seq_len, feature_dim),
-        torch.full((batch_size,), seq_len, dtype=torch.int64),
-        warmup=0.5,
-    )
+    # Make sure the forward pass runs, and that orthogonalize() does not
+    # change its output.
+    feats = torch.randn(batch_size, seq_len, feature_dim)
+    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
+    c.eval()  # use test mode and warmup=1.0 so it is deterministic.
+    y1 = c(feats, x_lens, warmup=1.0)[0]
+    c.orthogonalize()
+    y2 = c(feats, x_lens, warmup=1.0)[0]
+
+    diff_norm = (y1 - y2).norm()
+    y_norm = y1.norm()
+    print(f"diff_norm={diff_norm}, y_norm={y_norm}")
+    assert diff_norm < 0.001 * y_norm
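A portability note if you run this self-test on a recent PyTorch: Tensor.symeig, used by get_proj() above, was deprecated and has since been removed; torch.linalg.eigh follows the same contract (ascending eigenvalues, eigenvectors as columns), so it is a drop-in replacement there. A quick standalone check of that equivalence (not part of the commit):

import torch

a = torch.randn(8, 8)
cov = a @ a.t()                # symmetric, like the summed covariance in get_proj()
l, U = torch.linalg.eigh(cov)  # replacement for cov.symeig(eigenvectors=True)
# The decomposition reconstructs cov, so returning U.t() as the rotation still works.
assert torch.allclose(U @ torch.diag(l) @ U.t(), cov, atol=1e-4)
print(l)                       # eigenvalues in ascending order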
@@ -701,6 +701,11 @@ def train_one_epoch(
            continue
        cur_batch_idx = batch_idx
 
+        if params.batch_idx_train % 2000 == 0 and params.batch_idx_train > 0:
+            mmodel = model.module if hasattr(model, 'module') else model
+            mmodel.encoder.orthogonalize()
+            optimizer.reset()
+
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])
 
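This last hunk is in the training loop rather than the encoder file: every 2000 training batches it unwraps a possible DistributedDataParallel wrapper (the hasattr(model, 'module') check), rotates the encoder's self-attention parameters in place, and then calls optimizer.reset(). reset() is not a torch.optim method; it is assumed here to come from the custom optimizer used on this branch, and resetting right after the rotation is plausible because the previously accumulated per-element moment estimates no longer correspond to the rotated parameters. For a stock Adam, a hypothetical stand-in (only an illustration, not necessarily what reset() does) would simply clear the lazily created state:

import torch

def reset_optimizer_state(optimizer: torch.optim.Optimizer) -> None:
    """Drop per-parameter state (e.g. Adam's exp_avg / exp_avg_sq / step) so the
    moment estimates are re-accumulated from scratch for the rotated parameters."""
    optimizer.state.clear()  # Adam rebuilds state entries on its next step()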