mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Change dropout_rate from 0.2 to 0.1; fix logging statement; fix assignment to rand_scales, nonrand_scales to use [:]
This commit is contained in:
parent
a6050cb2de
commit
135be1e19c
@ -786,7 +786,7 @@ class Decorrelate(torch.nn.Module):
|
||||
U, S, _ = norm_cov.svd()
|
||||
|
||||
if random.random() < 0.1:
|
||||
print("Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
|
||||
logging.info(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
|
||||
|
||||
# row indexes of U correspond to channels, column indexes correspond to
|
||||
# singular values: cov = U * diag(S) * U.t() where * is matmul.
|
||||
@ -817,8 +817,8 @@ class Decorrelate(torch.nn.Module):
|
||||
|
||||
# rand_proportion is viewed as representing a proportion of the covariance, since
|
||||
# the random and nonrandom components will not be correlated.
|
||||
self.rand_scales = rand_proportion.sqrt()
|
||||
self.nonrand_scales = (1.0 - rand_proportion).sqrt()
|
||||
self.rand_scales[:] = rand_proportion.sqrt()
|
||||
self.nonrand_scales[:] = (1.0 - rand_proportion).sqrt()
|
||||
|
||||
|
||||
if True:
|
||||
|
@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.2)
|
||||
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.1)
|
||||
|
||||
|
||||
def forward(
|
||||
|
Loading…
x
Reference in New Issue
Block a user