Merge Decorrelate work, and simplification to RandomCombine, into pruned_transducer_stateless7

Daniel Povey 2022-06-11 11:07:07 +08:00
parent 30379bb35d
commit d301f8ac6c
2 changed files with 323 additions and 26 deletions

View File

@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import List, Optional, Tuple
+import logging

 import torch
 from encoder_interface import EncoderInterface
 from scaling import (
@@ -29,6 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
+    Decorrelate,
 )
 from torch import Tensor, nn
@ -93,6 +94,9 @@ class Conformer(EncoderInterface):
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
) )
self.decorrelate = Decorrelate(d_model, scale=0.05)
def forward( def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -128,6 +132,8 @@ class Conformer(EncoderInterface):
             x, pos_emb, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C)
+        x = self.decorrelate(x)
+
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

         return x, lengths
@@ -201,6 +207,7 @@ class ConformerEncoderLayer(nn.Module):
         self.dropout = nn.Dropout(dropout)

+
     def forward(
         self,
         src: Tensor,
@@ -302,7 +309,6 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels
         self.combiner = RandomCombine(
             num_inputs=len(self.aux_layers),
-            num_channels=num_channels,
             final_weight=0.5,
             pure_prob=0.333,
             stddev=2.0,
@@ -371,7 +377,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
@@ -1032,6 +1038,7 @@ class Conv2dSubsampling(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )

+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -1073,7 +1080,6 @@ class RandomCombine(nn.Module):
     def __init__(
         self,
         num_inputs: int,
-        num_channels: int,
         final_weight: float = 0.5,
         pure_prob: float = 0.5,
         stddev: float = 2.0,
@@ -1084,8 +1090,6 @@ class RandomCombine(nn.Module):
             The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module.  E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          num_channels:
-            The number of channels on the input, e.g. 512.
           final_weight:
             The amount of weight or probability we assign to the
             final layer when randomly choosing layers or when choosing
@@ -1116,13 +1120,6 @@ class RandomCombine(nn.Module):
         assert 0 < final_weight < 1, final_weight
         assert num_inputs >= 1

-        self.linear = nn.ModuleList(
-            [
-                nn.Linear(num_channels, num_channels, bias=True)
-                for _ in range(num_inputs - 1)
-            ]
-        )
-
         self.num_inputs = num_inputs
         self.final_weight = final_weight
         self.pure_prob = pure_prob
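
The docstring is only partly visible in this hunk, but together with the values used in ConformerEncoder (final_weight=0.5, pure_prob=0.333, stddev=2.0) it suggests the intent: the final layer's output keeps roughly final_weight of the combination mass in expectation, with the rest shared among the earlier tapped layers. A small sanity-check sketch under that reading of the docstring (not taken from the commit; the helper name is illustrative):

# Hypothetical sketch of the expected per-layer mixing weights implied by the
# RandomCombine docstring; not part of the commit.
def expected_weights(num_inputs: int, final_weight: float):
    # Assumed reading: the final layer keeps `final_weight` of the mass in
    # expectation, and the remaining layers share the rest equally.
    rest = (1.0 - final_weight) / (num_inputs - 1)
    return [rest] * (num_inputs - 1) + [final_weight]

print(expected_weights(num_inputs=3, final_weight=0.5))  # [0.25, 0.25, 0.5]
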
@@ -1135,12 +1132,7 @@ class RandomCombine(nn.Module):
             .log()
             .item()
         )
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        for i in range(len(self.linear)):
-            nn.init.eye_(self.linear[i].weight)
-            nn.init.constant_(self.linear[i].bias, 0.0)

     def forward(self, inputs: List[Tensor]) -> Tensor:
         """Forward function.
@@ -1161,14 +1153,9 @@ class RandomCombine(nn.Module):
         num_channels = inputs[0].shape[-1]
         num_frames = inputs[0].numel() // num_channels

-        mod_inputs = []
-        for i in range(num_inputs - 1):
-            mod_inputs.append(self.linear[i](inputs[i]))
-        mod_inputs.append(inputs[num_inputs - 1])
-
         ndim = inputs[0].ndim
         # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
             (num_frames, num_channels, num_inputs)
         )
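
After the simplification, forward() stacks the raw layer outputs directly instead of first passing all but the last through per-input nn.Linear projections. A minimal shape-only sketch of the stack-and-reshape step shown above (tensor sizes are illustrative, not from the commit):

import torch

# Three aux-layer outputs of shape (T, N, C), e.g. T=20 frames, N=4 utterances, C=8 channels.
inputs = [torch.randn(20, 4, 8) for _ in range(3)]
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim

# Stack along a new trailing dim, then flatten frames: (num_frames, num_channels, num_inputs).
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(num_frames, num_channels, 3)
print(stacked_inputs.shape)  # torch.Size([80, 8, 3])
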
@@ -1282,7 +1269,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     num_channels = 50
     m = RandomCombine(
         num_inputs=num_inputs,
-        num_channels=num_channels,
         final_weight=final_weight,
         pure_prob=pure_prob,
         stddev=stddev,
@@ -1329,5 +1315,8 @@ def _test_conformer_main():


 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     _test_conformer_main()
     _test_random_combine_main()
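
Putting the encoder-side changes together: RandomCombine is now constructed without num_channels (its per-input nn.Linear projections are gone), and the encoder output is passed through a Decorrelate module that is an identity in the forward pass. A minimal construction sketch mirroring the calls visible in the diff, assuming the two files are the conformer and scaling modules of pruned_transducer_stateless7 and that this directory is on PYTHONPATH (an assumption, not stated in the page):

# Sketch only; mirrors the constructor calls in the diff above.
from scaling import Decorrelate       # assumed module path
from conformer import RandomCombine   # assumed module path; no num_channels argument anymore

d_model = 512  # illustrative value

decorrelate = Decorrelate(d_model, scale=0.05)  # identity in forward(), shapes gradients in backward()

combiner = RandomCombine(
    num_inputs=4,        # e.g. len(self.aux_layers); illustrative value
    final_weight=0.5,
    pure_prob=0.333,
    stddev=2.0,
)
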

View File

@@ -18,7 +18,9 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
+import logging
+import random

 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -335,6 +337,263 @@ class DoubleSwish(torch.nn.Module):
+class GaussProjDrop(torch.nn.Module):
+    """
+    This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
+    The directions of dropout are fixed when the class is initialized, and are orthogonal.
+
+    dropout_rate: the dropout probability (actually will define the number of zeroed-out directions)
+    channel_dim: the axis corresponding to the channel, e.g. -1, 0, 1, 2.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 dropout_rate: float = 0.1,
+                 channel_dim: int = -1):
+        super(GaussProjDrop, self).__init__()
+        self.dropout_rate = dropout_rate
+        # this formula for rand_scale was found empirically, trying to match the
+        # statistics of dropout in terms of cross-correlation with the input, see
+        # _test_gauss_proj_drop()
+        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5  # * (num_channels ** -0.5)
+        self.channel_dim = channel_dim
+
+        rand_mat = torch.randn(num_channels, num_channels)
+        U, _, _ = rand_mat.svd()
+        self.register_buffer('U', U)  # a random orthogonal square matrix.  will be a buffer.
+
+    def _randperm_like(self, x: Tensor):
+        """
+        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
+        with the same shape as x.  All dimensions of x other than the last dimension
+        will be treated as batch dimensions.
+
+        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
+
+        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
+        we normally set channel dims.  This is required for some number theoretic stuff.
+        """
+        n = x.shape[-1]
+        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
+        b = x.numel() // n
+        randint = random.randint(0, 1000)
+        perm = torch.randperm(n, device=x.device)
+        # ensure all elements of batch_rand are coprime to n; this will ensure
+        # that multiplying the permutation by batch_rand and taking modulo
+        # n leaves us with permutations.
+        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
+        batch_rand = batch_rand.unsqueeze(-1)
+        ans = (perm * batch_rand) % n
+        ans = ans.reshape(x.shape)
+        return ans
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training:
+            return x
+        else:
+            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            x_bypass = x  # will be used for "+ I"
+            perm = self._randperm_like(x)
+            x = torch.gather(x, -1, perm)
+            # self.U will act like a different matrix for every row of x, because of the random
+            # permutation.
+            x = torch.matmul(x, self.U)
+            x_next = torch.empty_like(x)
+            # scatter_ uses perm in opposite way
+            # from gather, inverting it.
+            x_next.scatter_(-1, perm, x)
+            x = (x_next * self.rand_scale + x_bypass)
+            return x
+
+
+def _compute_correlation_loss(cov: Tensor,
+                              eps: float) -> Tensor:
+    """
+    Computes the correlation `loss`, which would be zero if the channel dimensions
+    are un-correlated, and equals num_channels if they are maximally correlated
+    (i.e., they all point in the same direction)
+
+    Args:
+      cov: Uncentered covariance matrix of shape (num_channels, num_channels),
+           does not have to be normalized by count.
+    """
+    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
+    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+    num_channels = cov.shape[0]
+    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
+    return loss
+
+
+def _update_cov_stats(cov: Tensor,
+                      x: Tensor,
+                      beta: float) -> Tensor:
+    """
+    Updates covariance stats as a decaying sum, returning the result.
+      cov: Old covariance stats, to be added to and returned, of shape
+           (num_channels, num_channels)
+      x: Tensor of features/activations, of shape (num_frames, num_channels)
+      beta: The decay constant for the stats, e.g. 0.8.
+    """
+    new_cov = torch.matmul(x.t(), x)
+    return cov * beta + new_cov * (1-beta)
+
+
+class DecorrelateFunction(torch.autograd.Function):
+    """
+    Function object for a function that does nothing in the forward pass;
+    but, in the backward pass, adds derivatives that encourage the channel dimensions
+    to be un-correlated with each other.
+
+    This should not be used in a half-precision-enabled area, use
+    with torch.cuda.amp.autocast(enabled=False).
+
+    Args:
+        x: The input tensor, which is also the function output.  It can have
+           arbitrary shape, but its dimension `channel_dim` (e.g. -1) will be
+           interpreted as the channel dimension.
+        old_cov: Covariance statistics from previous frames, accumulated as a decaying
+           sum over frames (not average), decaying by beta each time.
+        scale: The scale on the derivatives arising from this, expressed as
+           a fraction of the norm of the derivative at the output (applied per
+           frame).  We will further scale down the derivative if the normalized
+           loss is less than 1.
+        eps: Epsilon value to prevent division by zero when estimating diagonal
+           covariance.
+        beta: Decay constant that determines how we combine stats from x with
+           stats from cov; e.g. 0.8.
+        channel_dim: The dimension/axis of x corresponding to the channel, e.g. 0, 1, 2, -1.
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, old_cov: Tensor,
+                scale: float, eps: float, beta: float,
+                channel_dim: int) -> Tensor:
+        ctx.save_for_backward(x.detach(), old_cov.detach())
+        ctx.scale = scale
+        ctx.eps = eps
+        ctx.beta = beta
+        ctx.channel_dim = channel_dim
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+        x, old_cov = ctx.saved_tensors
+        # Reshape x and x_grad to be (num_frames, num_channels)
+        x = x.transpose(-1, ctx.channel_dim)
+        x_grad = x_grad.transpose(-1, ctx.channel_dim)
+        num_channels = x.shape[-1]
+        full_shape = x.shape
+        x = x.reshape(-1, num_channels)
+        x_grad = x_grad.reshape(-1, num_channels)
+
+        x.requires_grad = True
+
+        # Now, normalize the contributions of frames/pixels x to the covariance,
+        # to have magnitudes proportional to the norm of the gradient on that
+        # frame; the goal is to exclude "don't-care" frames such as padding frames from
+        # the computation.
+        x_grad_sqrt_norm = (x_grad ** 2).sum(dim=1) ** 0.25
+        x_grad_sqrt_norm /= x_grad_sqrt_norm.mean()
+        x_grad_sqrt_norm_is_inf = (x_grad_sqrt_norm - x_grad_sqrt_norm != 0)
+        x_grad_sqrt_norm.masked_fill_(x_grad_sqrt_norm_is_inf, 1.0)
+
+        with torch.enable_grad():
+            # scale up frames with larger grads.
+            x_scaled = x * x_grad_sqrt_norm.unsqueeze(-1)
+            cov = _update_cov_stats(old_cov, x_scaled, ctx.beta)
+            assert old_cov.dtype != torch.float16
+            old_cov[:] = cov  # update the stats outside!  This is not really
+                              # how backprop is supposed to work, but this input
+                              # is not differentiable..
+            loss = _compute_correlation_loss(cov, ctx.eps)
+            assert loss.dtype == torch.float32
+            if random.random() < 0.05:
+                logging.info(f"Decorrelate: loss = {loss}")
+            loss.backward()
+        decorr_x_grad = x.grad
+        scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
+        decorr_x_grad = decorr_x_grad * scale
+        x_grad = x_grad + decorr_x_grad
+        # reshape back to original shape
+        x_grad = x_grad.reshape(full_shape)
+        x_grad = x_grad.transpose(-1, ctx.channel_dim)
+        return x_grad, None, None, None, None, None
+
+
+class Decorrelate(torch.nn.Module):
+    """
+    This module does nothing in the forward pass, but in the backward pass, modifies
+    the derivatives in such a way as to encourage the dimensions of its input to become
+    decorrelated.
+
+    Args:
+      num_channels: The number of channels, e.g. 256.
+      apply_prob_decay: The probability with which we apply this each time, in
+              training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
+      scale: This number determines the scale of the gradient contribution from
+              this module, relative to whatever the gradient was before;
+              this is applied per frame or pixel, by scaling gradients.
+      eps: An epsilon used to prevent division by zero.
+      beta: A value 0 < beta < 1 that controls decay of covariance stats
+      channel_dim: The dimension of the input corresponding to the channel, e.g.
+              -1, 0, 1, 2.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 scale: float = 0.1,
+                 apply_prob_decay: int = 1000,
+                 eps: float = 1.0e-05,
+                 beta: float = 0.95,
+                 channel_dim: int = -1):
+        super(Decorrelate, self).__init__()
+        self.scale = scale
+        self.apply_prob_decay = apply_prob_decay
+        self.eps = eps
+        self.beta = beta
+        self.channel_dim = channel_dim
+
+        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
+        # step_buf is a copy of step, included so it will be loaded/saved with
+        # the model.
+        self.register_buffer('step_buf', torch.tensor(0.0))
+        self.step = 0
+
+    def load_state_dict(self, *args, **kwargs):
+        super(Decorrelate, self).load_state_dict(*args, **kwargs)
+        self.step = int(self.step_buf.item())
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training:
+            return x
+        else:
+            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
+            self.step += 1
+            self.step_buf.fill_(float(self.step))
+            if random.random() > apply_prob:
+                return x
+            with torch.cuda.amp.autocast(enabled=False):
+                x = x.to(torch.float32)
+                # the function updates self.cov in its backward pass (it needs the gradient
+                # norm, for frame weighting).
+                ans = DecorrelateFunction.apply(x, self.cov,
+                                                self.scale, self.eps, self.beta,
+                                                self.channel_dim)  # == x.
+                return ans
+
+
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
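
The helper _compute_correlation_loss above is the quantity whose gradient Decorrelate feeds back: it is near zero when the (uncentered) covariance is diagonal and grows as channels become collinear. A small self-contained check of that behaviour, reusing the same formula (this snippet is not part of the commit):

import torch

def correlation_loss(cov: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Same formula as _compute_correlation_loss in the diff above.
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    return ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / cov.shape[0]

num_channels = 4
x_uncorr = torch.randn(10000, num_channels)               # roughly uncorrelated channels
x_corr = torch.randn(10000, 1).repeat(1, num_channels)    # perfectly correlated channels

print(correlation_loss(x_uncorr.t() @ x_uncorr))  # close to 0
print(correlation_loss(x_corr.t() @ x_corr))      # close to num_channels - 1 (diagonal is excluded)
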
@@ -344,7 +603,7 @@ def _test_activation_balancer_sign():
     m = ActivationBalancer(
         channel_dim=0,
         min_positive=0.05,
-        max_positive=0.95,
+        max_positive=0.98,
         max_factor=0.2,
         min_abs=0.0,
     )
@@ -408,7 +667,56 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


+def _test_gauss_proj_drop():
+    D = 384
+    x = torch.randn(30000, D)
+
+    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
+        m1 = torch.nn.Dropout(dropout_rate)
+        m2 = GaussProjDrop(D, dropout_rate)
+        for mode in ['train', 'eval']:
+            y1 = m1(x)
+            y2 = m2(x)
+            xmag = (x*x).mean()
+            y1mag = (y1*y1).mean()
+            cross1 = (x*y1).mean()
+            y2mag = (y2*y2).mean()
+            cross2 = (x*y2).mean()
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            m1.eval()
+            m2.eval()
+
+
+def _test_decorrelate():
+    D = 384
+    x = torch.randn(30000, D)
+    # give it a non-unit covariance.
+    m = torch.randn(D, D) * (D ** -0.5)
+    _, S, _ = m.svd()
+    print("M eigs = ", S[::10])
+    x = torch.matmul(x, m)
+
+    # check that class Decorrelate does not crash when running..
+    decorrelate = Decorrelate(D)
+    x.requires_grad = True
+    y = decorrelate(x)
+    y.sum().backward()
+
+    decorrelate2 = Decorrelate(D)
+    decorrelate2.load_state_dict(decorrelate.state_dict())
+    assert decorrelate2.step == decorrelate.step
+
+
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_decorrelate()
+    _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
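
A closing note on how Decorrelate behaves over training, per the code above: forward() is the identity on the values; it only engages the backward-pass term with a probability that decays as apply_prob_decay / (step + apply_prob_decay), and the added gradient term is scaled so its RMS is roughly `scale` times the RMS of the incoming gradient (see DecorrelateFunction.backward). A short sketch of that decay schedule with the default apply_prob_decay=1000 (values computed from the formula, not measured in training):

# Probability that Decorrelate engages its backward-pass term at a given step,
# per the formula in Decorrelate.forward() above (apply_prob_decay = 1000).
apply_prob_decay = 1000
for step in [0, 100, 1000, 10000, 100000]:
    apply_prob = apply_prob_decay / (step + apply_prob_decay)
    print(step, round(apply_prob, 3))
# 0 1.0, 100 0.909, 1000 0.5, 10000 0.091, 100000 0.01
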