Replace dropout2 on Conv2dSubsampling with Dropout3, share time dim

This commit is contained in:
Daniel Povey 2023-01-11 13:18:08 +08:00
parent 1774853bdf
commit 3fdfec1049

View File

@ -2013,7 +2013,7 @@ class Conv2dSubsampling(nn.Module):
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BasicNorm(out_channels)
self.dropout = Dropout2(dropout)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(self, x: torch.Tensor) -> torch.Tensor: