mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove caching eval
This commit is contained in:
parent
1580c1c1cc
commit
1774853bdf
@ -1908,7 +1908,7 @@ class Dropout3(nn.Module):
|
|||||||
rand_shape[self.shared_dim] = 1
|
rand_shape[self.shared_dim] = 1
|
||||||
mask = torch.rand(*rand_shape, device=x.device) > p
|
mask = torch.rand(*rand_shape, device=x.device) > p
|
||||||
ans = MulForDropout3.apply(x, mask, scale)
|
ans = MulForDropout3.apply(x, mask, scale)
|
||||||
return mask
|
return ans
|
||||||
|
|
||||||
|
|
||||||
class SwooshLFunction(torch.autograd.Function):
|
class SwooshLFunction(torch.autograd.Function):
|
||||||
|
|||||||
@ -1868,7 +1868,9 @@ class ConvNeXt(nn.Module):
|
|||||||
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
|
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
|
||||||
else:
|
else:
|
||||||
mask = None
|
mask = None
|
||||||
return caching_eval(self.forward_internal, x, mask)
|
# turns out this caching idea does not work with --world-size > 1
|
||||||
|
#return caching_eval(self.forward_internal, x, mask)
|
||||||
|
return self.forward_internal(x, mask)
|
||||||
|
|
||||||
|
|
||||||
def forward_internal(self,
|
def forward_internal(self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user