mirror of https://github.com/k2-fsa/icefall.git
Remove caching eval
This commit is contained in:
parent 1580c1c1cc
commit 1774853bdf
@@ -1908,7 +1908,7 @@ class Dropout3(nn.Module):
         rand_shape[self.shared_dim] = 1
         mask = torch.rand(*rand_shape, device=x.device) > p
         ans = MulForDropout3.apply(x, mask, scale)
-        return mask
+        return ans


 class SwooshLFunction(torch.autograd.Function):
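For context on the fix above: the old line returned the boolean dropout mask instead of the masked, rescaled activations. Below is a minimal runnable sketch of the same shared-mask dropout pattern; SharedDimDropout is a hypothetical stand-in for icefall's Dropout3, with plain tensor ops in place of MulForDropout3.

# Minimal sketch, assuming a hypothetical SharedDimDropout module;
# not icefall's actual Dropout3 implementation.
import torch
import torch.nn as nn
from torch import Tensor

class SharedDimDropout(nn.Module):
    """Dropout that draws one mask with size 1 along `shared_dim` and
    broadcasts it, so e.g. a whole time step is kept or dropped together."""

    def __init__(self, p: float = 0.1, shared_dim: int = -1):
        super().__init__()
        self.p = p
        self.shared_dim = shared_dim

    def forward(self, x: Tensor) -> Tensor:
        if not self.training or self.p == 0.0:
            return x
        scale = 1.0 / (1.0 - self.p)
        rand_shape = list(x.shape)
        rand_shape[self.shared_dim] = 1  # mask broadcasts along shared_dim
        mask = torch.rand(*rand_shape, device=x.device) > self.p
        ans = x * mask * scale
        return ans  # the bug fixed above: return ans, not mask

drop = SharedDimDropout(p=0.5, shared_dim=-1).train()
y = drop(torch.randn(4, 100, 256))
assert y.shape == (4, 100, 256)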
@@ -1868,7 +1868,9 @@ class ConvNeXt(nn.Module):
             mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
         else:
             mask = None
-        return caching_eval(self.forward_internal, x, mask)
+        # turns out this caching idea does not work with --world-size > 1
+        #return caching_eval(self.forward_internal, x, mask)
+        return self.forward_internal(x, mask)


     def forward_internal(self,
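The change above stops routing the call through caching_eval(): as the new comment notes, caching the result of forward_internal does not work with distributed training (--world-size > 1), so the module body is now called directly. A minimal sketch of the surrounding per-sample layer-drop logic, assuming a hypothetical ConvNeXtLike module (icefall's real ConvNeXt has a different body):

# Minimal sketch, assuming a hypothetical ConvNeXtLike module; only the
# layer-drop mask logic mirrors the diff above.
import torch
import torch.nn as nn
from torch import Tensor

class ConvNeXtLike(nn.Module):
    def __init__(self, channels: int, layerdrop_rate: float = 0.1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.layerdrop_rate = layerdrop_rate

    def forward(self, x: Tensor) -> Tensor:
        if self.training and self.layerdrop_rate != 0.0:
            batch_size = x.shape[0]
            # One keep/drop decision per sample, broadcast over (C, H, W).
            mask = (torch.rand((batch_size, 1, 1, 1),
                               dtype=x.dtype, device=x.device)
                    > self.layerdrop_rate)
        else:
            mask = None
        # Call the body directly, as in the commit; the cached variant
        # was removed because it breaks with --world-size > 1.
        return self.forward_internal(x, mask)

    def forward_internal(self, x: Tensor, mask: Tensor) -> Tensor:
        bypass = x
        x = self.conv(x)
        if mask is not None:
            x = x * mask  # dropped samples pass through the bypass only
        return bypass + x

m = ConvNeXtLike(channels=8).train()
out = m(torch.randn(2, 8, 16, 16))
assert out.shape == (2, 8, 16, 16)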