mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-18 21:44:18 +00:00
Merge Decorrelate work, and simplification to RandomCombine, into pruned_transducer_stateless7
parent 30379bb35d
commit d301f8ac6c
@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import List, Optional, Tuple

import logging
import torch
from encoder_interface import EncoderInterface
from scaling import (
@@ -29,6 +29,7 @@ from scaling import (
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,
    Decorrelate,
)
from torch import Tensor, nn

@@ -93,6 +94,9 @@ class Conformer(EncoderInterface):
            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
        )

        self.decorrelate = Decorrelate(d_model, scale=0.05)


    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -128,6 +132,8 @@ class Conformer(EncoderInterface):
            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
        ) # (T, N, C)

        x = self.decorrelate(x)

        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return x, lengths
@@ -201,6 +207,7 @@ class ConformerEncoderLayer(nn.Module):

        self.dropout = nn.Dropout(dropout)


    def forward(
        self,
        src: Tensor,
@@ -302,7 +309,6 @@ class ConformerEncoder(nn.Module):
        num_channels = encoder_layer.norm_final.num_channels
        self.combiner = RandomCombine(
            num_inputs=len(self.aux_layers),
            num_channels=num_channels,
            final_weight=0.5,
            pure_prob=0.333,
            stddev=2.0,
@@ -371,7 +377,7 @@ class RelPositionalEncoding(torch.nn.Module):
        """Construct an PositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

@@ -1032,6 +1038,7 @@ class Conv2dSubsampling(nn.Module):
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

@@ -1073,7 +1080,6 @@ class RandomCombine(nn.Module):
    def __init__(
        self,
        num_inputs: int,
        num_channels: int,
        final_weight: float = 0.5,
        pure_prob: float = 0.5,
        stddev: float = 2.0,
@@ -1084,8 +1090,6 @@ class RandomCombine(nn.Module):
            The number of tensor inputs, which equals the number of layers'
            outputs that are fed into this module. E.g. in an 18-layer neural
            net if we output layers 16, 12, 18, num_inputs would be 3.
          num_channels:
            The number of channels on the input, e.g. 512.
          final_weight:
            The amount of weight or probability we assign to the
            final layer when randomly choosing layers or when choosing
@@ -1116,13 +1120,6 @@ class RandomCombine(nn.Module):
        assert 0 < final_weight < 1, final_weight
        assert num_inputs >= 1

        self.linear = nn.ModuleList(
            [
                nn.Linear(num_channels, num_channels, bias=True)
                for _ in range(num_inputs - 1)
            ]
        )

        self.num_inputs = num_inputs
        self.final_weight = final_weight
        self.pure_prob = pure_prob
@@ -1135,12 +1132,7 @@ class RandomCombine(nn.Module):
            .log()
            .item()
        )
        self._reset_parameters()

    def _reset_parameters(self):
        for i in range(len(self.linear)):
            nn.init.eye_(self.linear[i].weight)
            nn.init.constant_(self.linear[i].bias, 0.0)

    def forward(self, inputs: List[Tensor]) -> Tensor:
        """Forward function.
@@ -1161,14 +1153,9 @@ class RandomCombine(nn.Module):
        num_channels = inputs[0].shape[-1]
        num_frames = inputs[0].numel() // num_channels

        mod_inputs = []
        for i in range(num_inputs - 1):
            mod_inputs.append(self.linear[i](inputs[i]))
        mod_inputs.append(inputs[num_inputs - 1])

        ndim = inputs[0].ndim
        # stacked_inputs: (num_frames, num_channels, num_inputs)
        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
            (num_frames, num_channels, num_inputs)
        )

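# Editorial aside, not part of the commit: a minimal shape check of the stacking
# done just above, now that the per-input nn.Linear projections are gone and the
# raw `inputs` are stacked directly. The function name and sizes below are
# invented for illustration; it only assumes the `torch` import shown in the
# first hunk.
def _example_random_combine_stacking():
    T, N, C, num_inputs = 5, 3, 8, 4
    inputs = [torch.randn(T, N, C) for _ in range(num_inputs)]
    num_frames = inputs[0].numel() // C
    stacked = torch.stack(inputs, dim=inputs[0].ndim).reshape(
        (num_frames, C, num_inputs)
    )
    assert stacked.shape == (T * N, C, num_inputs)
    # column k of the last axis still holds the unmodified k-th input.
    assert torch.equal(stacked[..., 0], inputs[0].reshape(num_frames, C))
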
@@ -1282,7 +1269,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
    num_channels = 50
    m = RandomCombine(
        num_inputs=num_inputs,
        num_channels=num_channels,
        final_weight=final_weight,
        pure_prob=pure_prob,
        stddev=stddev,
@@ -1329,5 +1315,8 @@ def _test_conformer_main():


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_conformer_main()
    _test_random_combine_main()

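# Editorial aside, not part of the commit: the Conformer changes above reduce to
# the pattern sketched here. Decorrelate (defined in the scaling changes below)
# is built once with the model dimension and applied to the (T, N, C) encoder
# output, where its default channel_dim=-1 selects the C axis. It is an exact
# identity outside training, so inference output is unchanged. The function name
# and sizes are invented for illustration.
def _example_decorrelate_is_identity_in_eval():
    d_model = 512
    decorrelate = Decorrelate(d_model, scale=0.05)
    x = torch.randn(100, 8, d_model)  # (T, N, C), as in Conformer.forward above
    decorrelate.eval()
    assert torch.equal(decorrelate(x), x)
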
@@ -18,7 +18,9 @@
import collections
from itertools import repeat
from typing import Optional, Tuple
import logging

import random
import torch
import torch.nn as nn
from torch import Tensor
@@ -335,6 +337,263 @@ class DoubleSwish(torch.nn.Module):



class GaussProjDrop(torch.nn.Module):
    """
    This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
    The directions of dropout are fixed when the class is initialized, and are orthogonal.

    dropout_rate: the dropout probability (actually will define the number of zeroed-out directions)
    channel_dim: the axis corresponding to the channel, e.g. -1, 0, 1, 2.
    """
    def __init__(self,
                 num_channels: int,
                 dropout_rate: float = 0.1,
                 channel_dim: int = -1):
        super(GaussProjDrop, self).__init__()
        self.dropout_rate = dropout_rate
        # this formula for rand_scale was found empirically, trying to match the
        # statistics of dropout in terms of cross-correlation with the input, see
        # _test_gauss_proj_drop()
        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5 # * (num_channels ** -0.5)

        self.channel_dim = channel_dim

        rand_mat = torch.randn(num_channels, num_channels)
        U, _, _ = rand_mat.svd()
        self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.


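    # Editorial aside, not part of the commit: one way to see where the
    # rand_scale formula above could come from. torch.nn.Dropout keeps
    # x / (1 - p) with probability (1 - p) and outputs 0 otherwise, so
    # E[y*x] = E[x*x] while E[y*y] = E[x*x] / (1 - p); the excess power
    # relative to the input is p / (1 - p), and rand_scale is its square root.
    # Worked number: for p = 0.1, rand_scale = (0.1 / 0.9) ** 0.5 = 1/3, giving
    # E[y*y] ~= E[x*x] * (1 + 1/9) = E[x*x] / 0.9, matching Dropout(p=0.1).
    # The code itself only claims the value was found empirically, so treat this
    # as a plausible reading rather than the author's derivation.
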
    def _randperm_like(self, x: Tensor):
        """
        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
        with the same shape as x. All dimensions of x other than the last dimension
        will be treated as batch dimensions.

        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.

        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
        we normally set channel dims. This is required for some number theoretic stuff.
        """
        n = x.shape[-1]

        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0

        b = x.numel() // n
        randint = random.randint(0, 1000)
        perm = torch.randperm(n, device=x.device)
        # ensure all elements of batch_rand are coprime to n; this will ensure
        # that multiplying the permutation by batch_rand and taking modulo
        # n leaves us with permutations.
        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
        batch_rand = batch_rand.unsqueeze(-1)
        ans = (perm * batch_rand) % n
        ans = ans.reshape(x.shape)
        return ans


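    # Editorial aside, not part of the commit: a small self-check of the
    # number-theoretic fact _randperm_like relies on -- multiplying a
    # permutation of [0, n) by a constant coprime to n, modulo n, yields
    # another permutation. The method name and constants are invented for
    # illustration (384 = 3 * 2**7 satisfies the assert above, and anything of
    # the form 6*k + 1, as built for batch_rand, is coprime to such an n).
    def _example_coprime_permutation_check(self):
        n = 384
        perm = torch.randperm(n)
        for k in (7, 385, 2311):  # each coprime to 384 (odd, not divisible by 3)
            assert sorted(((perm * k) % n).tolist()) == list(range(n))
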
    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x
        else:
            x = x.transpose(self.channel_dim, -1) # (..., num_channels)
            x_bypass = x # will be used for "+ I"
            perm = self._randperm_like(x)
            x = torch.gather(x, -1, perm)
            # self.U will act like a different matrix for every row of x, because of the random
            # permutation.
            x = torch.matmul(x, self.U)
            x_next = torch.empty_like(x)
            # scatter_ uses perm in opposite way
            # from gather, inverting it.
            x_next.scatter_(-1, perm, x)
            x = (x_next * self.rand_scale + x_bypass)
            return x


def _compute_correlation_loss(cov: Tensor,
                              eps: float) -> Tensor:
    """
    Computes the correlation `loss`, which would be zero if the channel dimensions
    are un-correlated, and equals num_channels if they are maximally correlated
    (i.e., they all point in the same direction)
    Args:
        cov: Uncentered covariance matrix of shape (num_channels, num_channels),
             does not have to be normalized by count.
    """
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    num_channels = cov.shape[0]
    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
    return loss

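# Editorial aside, not part of the commit: a numeric check of the two extremes
# described in the docstring above. The function name is invented; it assumes
# _compute_correlation_loss as defined above and a small eps.
def _example_correlation_loss_extremes():
    n = 8
    # perfectly de-correlated channels: the off-diagonal of norm_cov is zero.
    assert _compute_correlation_loss(torch.eye(n), eps=1.0e-05).item() < 1.0e-08
    # all channels identical: every off-diagonal entry of norm_cov is ~1, so the
    # loss is about (n*n - n) / n = n - 1, i.e. close to num_channels for large n.
    loss = _compute_correlation_loss(torch.ones(n, n), eps=1.0e-05).item()
    assert abs(loss - (n - 1)) < 0.01
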
def _update_cov_stats(cov: Tensor,
                      x: Tensor,
                      beta: float) -> Tensor:
    """
    Updates covariance stats as a decaying sum, returning the result.
      cov: Old covariance stats, to be added to and returned, of shape
           (num_channels, num_channels)
      x: Tensor of features/activations, of shape (num_frames, num_channels)
      beta: The decay constant for the stats, e.g. 0.8.
    """
    new_cov = torch.matmul(x.t(), x)
    return cov * beta + new_cov * (1-beta)


class DecorrelateFunction(torch.autograd.Function):
    """
    Function object for a function that does nothing in the forward pass;
    but, in the backward pass, adds derivatives that encourage the channel dimensions
    to be un-correlated with each other.
    This should not be used in a half-precision-enabled area, use
    with torch.cuda.amp.autocast(enabled=False).

    Args:
        x: The input tensor, which is also the function output. It can have
           arbitrary shape, but its dimension `channel_dim` (e.g. -1) will be
           interpreted as the channel dimension.
        old_cov: Covariance statistics from previous frames, accumulated as a decaying
           sum over frames (not average), decaying by beta each time.
        scale: The scale on the derivatives arising from this, expressed as
           a fraction of the norm of the derivative at the output (applied per
           frame). We will further scale down the derivative if the normalized
           loss is less than 1.
        eps: Epsilon value to prevent division by zero when estimating diagonal
           covariance.
        beta: Decay constant that determines how we combine stats from x with
           stats from cov; e.g. 0.8.
        channel_dim: The dimension/axis of x corresponding to the channel, e.g. 0, 1, 2, -1.
    """
    @staticmethod
    def forward(ctx, x: Tensor, old_cov: Tensor,
                scale: float, eps: float, beta: float,
                channel_dim: int) -> Tensor:
        ctx.save_for_backward(x.detach(), old_cov.detach())
        ctx.scale = scale
        ctx.eps = eps
        ctx.beta = beta
        ctx.channel_dim = channel_dim
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
        x, old_cov = ctx.saved_tensors

        # Reshape x and x_grad to be (num_frames, num_channels)
        x = x.transpose(-1, ctx.channel_dim)
        x_grad = x_grad.transpose(-1, ctx.channel_dim)
        num_channels = x.shape[-1]
        full_shape = x.shape
        x = x.reshape(-1, num_channels)
        x_grad = x_grad.reshape(-1, num_channels)
        x.requires_grad = True

        # Now, normalize the contributions of frames/pixels x to the covariance,
        # to have magnitudes proportional to the norm of the gradient on that
        # frame; the goal is to exclude "don't-care" frames such as padding frames from
        # the computation.
        x_grad_sqrt_norm = (x_grad ** 2).sum(dim=1) ** 0.25
        x_grad_sqrt_norm /= x_grad_sqrt_norm.mean()
        x_grad_sqrt_norm_is_inf = (x_grad_sqrt_norm - x_grad_sqrt_norm != 0)
        x_grad_sqrt_norm.masked_fill_(x_grad_sqrt_norm_is_inf, 1.0)

        with torch.enable_grad():
            # scale up frames with larger grads.
            x_scaled = x * x_grad_sqrt_norm.unsqueeze(-1)

            cov = _update_cov_stats(old_cov, x_scaled, ctx.beta)
            assert old_cov.dtype != torch.float16
            old_cov[:] = cov # update the stats outside! This is not really
            # how backprop is supposed to work, but this input
            # is not differentiable..
            loss = _compute_correlation_loss(cov, ctx.eps)
            assert loss.dtype == torch.float32

            if random.random() < 0.05:
                logging.info(f"Decorrelate: loss = {loss}")

            loss.backward()

        decorr_x_grad = x.grad

        scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
        decorr_x_grad = decorr_x_grad * scale

        x_grad = x_grad + decorr_x_grad

        # reshape back to original shape
        x_grad = x_grad.reshape(full_shape)
        x_grad = x_grad.transpose(-1, ctx.channel_dim)

        return x_grad, None, None, None, None, None



class Decorrelate(torch.nn.Module):
    """
    This module does nothing in the forward pass, but in the backward pass, modifies
    the derivatives in such a way as to encourage the dimensions of its input to become
    decorrelated.

    Args:
        num_channels: The number of channels, e.g. 256.
        apply_prob_decay: The probability with which we apply this each time, in
            training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
        scale: This number determines the scale of the gradient contribution from
            this module, relative to whatever the gradient was before;
            this is applied per frame or pixel, by scaling gradients.
        eps: An epsilon used to prevent division by zero.
        beta: A value 0 < beta < 1 that controls decay of covariance stats
        channel_dim: The dimension of the input corresponding to the channel, e.g.
            -1, 0, 1, 2.

    """
    def __init__(self,
                 num_channels: int,
                 scale: float = 0.1,
                 apply_prob_decay: int = 1000,
                 eps: float = 1.0e-05,
                 beta: float = 0.95,
                 channel_dim: int = -1):
        super(Decorrelate, self).__init__()
        self.scale = scale
        self.apply_prob_decay = apply_prob_decay
        self.eps = eps
        self.beta = beta
        self.channel_dim = channel_dim

        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
        # step_buf is a copy of step, included so it will be loaded/saved with
        # the model.
        self.register_buffer('step_buf', torch.tensor(0.0))
        self.step = 0


    def load_state_dict(self, *args, **kwargs):
        super(Decorrelate, self).load_state_dict(*args, **kwargs)
        self.step = int(self.step_buf.item())


    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x
        else:
            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
            self.step += 1
            self.step_buf.fill_(float(self.step))
            if random.random() > apply_prob:
                return x
            with torch.cuda.amp.autocast(enabled=False):
                x = x.to(torch.float32)
                # the function updates self.cov in its backward pass (it needs the gradient
                # norm, for frame weighting).
                ans = DecorrelateFunction.apply(x, self.cov,
                                                self.scale, self.eps, self.beta,
                                                self.channel_dim) # == x.
            return ans



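# Editorial aside, not part of the commit: a minimal sketch of how Decorrelate is
# meant to be dropped into a model (compare the Conformer change above, which
# applies it to the encoder output). The forward pass is exactly the identity;
# only the backward pass is modified, and only with probability
# apply_prob_decay / (apply_prob_decay + step). The class, function name and
# sizes below are invented for illustration.
def _example_decorrelate_usage():
    class TinyModel(nn.Module):
        def __init__(self, num_channels: int = 256):
            super().__init__()
            self.proj = nn.Linear(num_channels, num_channels)
            self.decorrelate = Decorrelate(num_channels, scale=0.05)

        def forward(self, x: Tensor) -> Tensor:
            return self.decorrelate(self.proj(x))

    model = TinyModel()
    x = torch.randn(10, 20, 256)
    y = model(x)
    assert torch.allclose(y, model.proj(x))  # output values are unchanged
    y.sum().backward()  # the decorrelation term only shows up in the gradients
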
def _test_activation_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
@@ -344,7 +603,7 @@ def _test_activation_balancer_sign():
    m = ActivationBalancer(
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_positive=0.98,
        max_factor=0.2,
        min_abs=0.0,
    )
@@ -408,7 +667,56 @@ def _test_double_swish_deriv():
    torch.autograd.gradcheck(m, x)


def _test_gauss_proj_drop():
    D = 384
    x = torch.randn(30000, D)


    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
        m1 = torch.nn.Dropout(dropout_rate)
        m2 = GaussProjDrop(D, dropout_rate)
        for mode in ['train', 'eval']:
            y1 = m1(x)
            y2 = m2(x)
            xmag = (x*x).mean()
            y1mag = (y1*y1).mean()
            cross1 = (x*y1).mean()
            y2mag = (y2*y2).mean()
            cross2 = (x*y2).mean()
            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
            m1.eval()
            m2.eval()


def _test_decorrelate():
    D = 384
    x = torch.randn(30000, D)
    # give it a non-unit covariance.
    m = torch.randn(D, D) * (D ** -0.5)
    _, S, _ = m.svd()
    print("M eigs = ", S[::10])
    x = torch.matmul(x, m)


    # check that class Decorrelate does not crash when running..
    decorrelate = Decorrelate(D)
    x.requires_grad = True
    y = decorrelate(x)
    y.sum().backward()

    decorrelate2 = Decorrelate(D)
    decorrelate2.load_state_dict(decorrelate.state_dict())
    assert decorrelate2.step == decorrelate.step




if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_decorrelate()
    _test_gauss_proj_drop()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()