Remove caching eval

Daniel Povey 2023-01-11 13:03:53 +08:00
parent 1580c1c1cc
commit 1774853bdf
2 changed files with 4 additions and 2 deletions


@@ -1908,7 +1908,7 @@ class Dropout3(nn.Module):
         rand_shape[self.shared_dim] = 1
         mask = torch.rand(*rand_shape, device=x.device) > p
         ans = MulForDropout3.apply(x, mask, scale)
-        return mask
+        return ans

 class SwooshLFunction(torch.autograd.Function):
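
For context on this hunk: Dropout3.forward was returning the boolean dropout mask instead of the masked output ans computed by MulForDropout3.apply. A minimal sketch of the intended computation, using plain tensor ops rather than the repo's custom autograd function (the helper name and signature below are illustrative only, not from the diff):

    import torch

    def dropout3_forward_sketch(x: torch.Tensor, p: float, shared_dim: int) -> torch.Tensor:
        # Sketch only: approximates Dropout3.forward with ordinary tensor ops
        # instead of MulForDropout3.apply.
        if p == 0.0:
            return x
        scale = 1.0 / (1.0 - p)
        rand_shape = list(x.shape)
        rand_shape[shared_dim] = 1                       # share one mask along shared_dim
        mask = torch.rand(*rand_shape, device=x.device) > p
        ans = x * mask * scale                           # masked, rescaled activations
        return ans                                       # the bug was returning `mask` here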


@@ -1868,7 +1868,9 @@ class ConvNeXt(nn.Module):
             mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
         else:
             mask = None
-        return caching_eval(self.forward_internal, x, mask)
+        # turns out this caching idea does not work with --world-size > 1
+        #return caching_eval(self.forward_internal, x, mask)
+        return self.forward_internal(x, mask)

     def forward_internal(self,
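
This hunk bypasses caching_eval (noted in the new comment as not working with --world-size > 1) and calls forward_internal directly after computing the per-example layerdrop mask. A small self-contained sketch of that pattern, with hypothetical module and layer names that are not taken from the repo:

    from typing import Optional
    import torch
    import torch.nn as nn

    class TinyConvNeXtLike(nn.Module):
        # Illustrative module only: shows layerdrop-style masking plus a direct
        # call to an internal forward, with no caching wrapper in between.
        def __init__(self, channels: int, layerdrop_rate: float = 0.1):
            super().__init__()
            self.layerdrop_rate = layerdrop_rate
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.training and self.layerdrop_rate > 0.0:
                batch_size = x.shape[0]
                mask = (
                    torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                    > self.layerdrop_rate
                )
            else:
                mask = None
            # call the internal forward directly (no caching_eval)
            return self.forward_internal(x, mask)

        def forward_internal(self, x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
            bypass = x
            x = self.conv(x)
            if mask is not None:
                x = x * mask              # zero the branch for dropped examples
            return bypass + x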