mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00
Bug fixes.
This commit is contained in:
parent
136ffb0597
commit
e535887abb
@ -655,9 +655,9 @@ class ProjDrop(torch.nn.Module):
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
||||
if self.training:
|
||||
return x * (1.0 - self.dropout_rate)
|
||||
if not self.training:
|
||||
# Scaling by (1 - dropout_rate) ** 0.5 makes (x**2).sum() shrink by (1 - dropout_rate), matching the scale seen in training mode.
|
||||
return x * ((1.0 - self.dropout_rate) ** 0.5)
|
||||
else:
|
||||
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
|
||||
num_channels = x.shape[-1]
|
||||
@ -667,15 +667,14 @@ class ProjDrop(torch.nn.Module):
|
||||
device=x.device, dtype=x.dtype)
|
||||
rr = torch.matmul(r, r.t()) # num_dropped by num_dropped
|
||||
rr += 0.01 # to 100% ensure it is invertible
|
||||
rr_inv = rr.cholesky.cholesky_inverse()
|
||||
rr_inv = rr.cholesky().cholesky_inverse()
|
||||
# OK, so r.t() rr_inv r is (approximately) a projection onto the row space of r: its nonzero eigenvalues are close to 1.
|
||||
xr = torch.matmul(x, r.t()) # (..., num_dropped)
|
||||
rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels)
|
||||
xrr = torch.matmul(xr, rr_inv_r) # (..., num_channels)
|
||||
return x - xrr
|
||||
|
||||
|
||||
|
||||
x = x - xrr
|
||||
x = x.transpose(self.channel_dim, -1)
|
||||
return x
|
||||
|
||||
|
||||
def _test_activation_balancer_sign():
|
||||
@ -752,14 +751,18 @@ def _test_double_swish_deriv():
|
||||
|
||||
|
||||
def _test_proj_drop():
|
||||
x = torch.randn(3000, 300)
|
||||
x = torch.randn(30000, 300)
|
||||
m = ProjDrop(0.1)
|
||||
y = m(x)
|
||||
xmag = (x*x).sqrt().mean()
|
||||
ymag = (y*y).sqrt().mean()
|
||||
xmag = (x*x).mean()
|
||||
ymag = (y*y).mean()
|
||||
print(f"xmag = {xmag}, ymag = {ymag}")
|
||||
assert abs((ymag / xmag) - 0.9) < 0.01
|
||||
|
||||
assert abs((ymag / xmag) - 0.9) < 0.02
|
||||
m.eval()
|
||||
y = m(x)
|
||||
ymag = (y*y).mean()
|
||||
print(f"xmag[eval] = {xmag}, ymag = {ymag}")
|
||||
assert abs((ymag / xmag) - 0.9) < 0.02
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_proj_drop()
|
||||
|
@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Construct a RelPositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.dropout = ProjDrop(p=dropout_rate)
|
||||
self.dropout = ProjDrop(dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user