Bug fixes.

Daniel Povey 2022-06-05 23:24:02 +08:00
parent 136ffb0597
commit e535887abb
2 changed files with 17 additions and 14 deletions


@@ -655,9 +655,9 @@ class ProjDrop(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        if self.training:
-            return x * (1.0 - self.dropout_rate)
+        if not self.training:
+            # The ** 0.5 is intended to reproduce the scale on (x**2).sum().
+            return x * ((1.0 - self.dropout_rate) ** 0.5)
         else:
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
             num_channels = x.shape[-1]
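
Why the ** 0.5: the eval-time scale is meant to match the energy (x**2).sum() that the projection dropout removes during training, not the expected value of x. A quick numeric check of that claim, as a sketch with an assumed rate p = 0.1 and an ordinary channel mask standing in for the random projection:

    import torch

    p = 0.1
    x = torch.randn(10000, 300)

    # Crude stand-in for projection dropout: zero out a random fraction p
    # of the channels.  The mean of x**2 drops by roughly (1 - p).
    mask = (torch.rand(x.shape[-1]) > p).to(x.dtype)
    train_energy = (x * mask).pow(2).mean()

    # Scaling by (1 - p) ** 0.5 at eval time reproduces that energy;
    # the old scale of (1 - p) would give (1 - p) ** 2 = 0.81 instead.
    eval_energy = (x * (1.0 - p) ** 0.5).pow(2).mean()

    print(train_energy.item(), eval_energy.item())  # both close to 0.9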
@@ -667,15 +667,14 @@ class ProjDrop(torch.nn.Module):
                             device=x.device, dtype=x.dtype)
             rr = torch.matmul(r, r.t())  # num_dropped by num_dropped
             rr += 0.01  # to 100% ensure it is invertible
-            rr_inv = rr.cholesky.cholesky_inverse()
+            rr_inv = rr.cholesky().cholesky_inverse()
             # OK, so r rr_inv r.t() will have eigenvalues of 1.
             xr = torch.matmul(x, r.t())  # (..., num_dropped)
             rr_inv_r = torch.matmul(rr_inv, r)  # (num_dropped, num_channels)
             xrr = torch.matmul(xr, rr_inv_r)  # (..., num_channels)
-            return x - xrr
+            x = x - xrr
+            x = x.transpose(self.channel_dim, -1)
+            return x

 def _test_activation_balancer_sign():
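
For reference, the linear algebra behind the comment that "r rr_inv r.t() will have eigenvalues of 1": P = r.t() @ (r @ r.t())^-1 @ r is an orthogonal projection onto the row space of r, so subtracting x @ P removes a random num_dropped-dimensional subspace and shrinks the energy by about num_dropped / num_channels. A small sketch (the sizes are illustrative, and torch.linalg.inv stands in for the Cholesky route used in the code):

    import torch

    num_channels, num_dropped = 300, 30   # illustrative sizes
    r = torch.randn(num_dropped, num_channels)

    rr = r @ r.t()                        # (num_dropped, num_dropped)
    rr_inv = torch.linalg.inv(rr)         # plain inverse for the sketch
    P = r.t() @ rr_inv @ r                # (num_channels, num_channels)

    # P is (up to float error) an orthogonal projection onto the row space
    # of r: idempotent, symmetric, with trace equal to num_dropped.
    print(torch.allclose(P @ P, P, atol=1e-4))   # True
    print(P.trace())                             # approx. num_dropped

    x = torch.randn(5, num_channels)
    y = x - x @ P   # drop the component of x lying in the random subspace
    # Energy shrinks by roughly num_dropped / num_channels on average.
    print((y * y).mean() / (x * x).mean())       # approx. 1 - 30/300 = 0.9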
@@ -752,14 +751,18 @@ def _test_double_swish_deriv():
 def _test_proj_drop():
-    x = torch.randn(3000, 300)
+    x = torch.randn(30000, 300)
     m = ProjDrop(0.1)
     y = m(x)
-    xmag = (x*x).sqrt().mean()
-    ymag = (y*y).sqrt().mean()
+    xmag = (x*x).mean()
+    ymag = (y*y).mean()
     print(f"xmag = {xmag}, ymag = {ymag}")
-    assert abs((ymag / xmag) - 0.9) < 0.01
+    assert abs((ymag / xmag) - 0.9) < 0.02
+    m.eval()
+    y = m(x)
+    ymag = (y*y).mean()
+    print(f"xmag[eval] = {xmag}, ymag = {ymag}")
+    assert abs((ymag / xmag) - 0.9) < 0.02

 if __name__ == "__main__":
     _test_proj_drop()
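
Putting the hunks above together, here is a self-contained sketch of the fixed behaviour and the updated test. The class name and the num_dropped computation are assumptions made for illustration; this is not the repository's actual ProjDrop.

    import torch
    from torch import Tensor


    class ProjDropSketch(torch.nn.Module):
        """Sketch of the behaviour the diff describes: in training, project
        out a random dropout_rate fraction of the feature space; in eval,
        scale by (1 - dropout_rate) ** 0.5 so (x**2).sum() matches."""

        def __init__(self, dropout_rate: float, channel_dim: int = -1):
            super().__init__()
            self.dropout_rate = dropout_rate
            self.channel_dim = channel_dim

        def forward(self, x: Tensor) -> Tensor:
            if not self.training:
                return x * ((1.0 - self.dropout_rate) ** 0.5)
            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
            num_channels = x.shape[-1]
            num_dropped = int(num_channels * self.dropout_rate)  # assumption
            r = torch.randn(num_dropped, num_channels,
                            device=x.device, dtype=x.dtype)
            rr = torch.matmul(r, r.t())   # (num_dropped, num_dropped)
            rr += 0.01                    # keep it safely invertible
            rr_inv = rr.cholesky().cholesky_inverse()
            xr = torch.matmul(x, r.t())           # (..., num_dropped)
            rr_inv_r = torch.matmul(rr_inv, r)    # (num_dropped, num_channels)
            x = x - torch.matmul(xr, rr_inv_r)    # remove the random subspace
            return x.transpose(self.channel_dim, -1)


    if __name__ == "__main__":
        x = torch.randn(30000, 300)
        m = ProjDropSketch(0.1)
        for mode in ("train", "eval"):
            m.train(mode == "train")
            y = m(x)
            # Both modes should give an energy ratio close to 0.9.
            print(mode, ((y * y).mean() / (x * x).mean()).item())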


@@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = ProjDrop(p=dropout_rate)
+        self.dropout = ProjDrop(dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
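
The second file's change just drops the p= keyword: unlike torch.nn.Dropout, ProjDrop presumably takes the rate as its first positional argument rather than a p parameter, so the keyword call would fail. Illustrated with the ProjDropSketch class from the sketch above:

    from torch import nn

    dropout_rate = 0.1
    ref = nn.Dropout(p=dropout_rate)       # torch.nn.Dropout does accept p=
    drop = ProjDropSketch(dropout_rate)    # ProjDrop-style: positional rate
    # ProjDropSketch(p=dropout_rate) would raise a TypeError, since the
    # signature sketched above has no `p` parameter.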