From 1774853bdf20da7be22ba0ebe32f9af51034f36c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 11 Jan 2023 13:03:53 +0800
Subject: [PATCH] Remove caching eval

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py   | 2 +-
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index f68051938..be7a4abd6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1908,7 +1908,7 @@ class Dropout3(nn.Module):
         rand_shape[self.shared_dim] = 1
         mask = torch.rand(*rand_shape, device=x.device) > p
         ans = MulForDropout3.apply(x, mask, scale)
-        return mask
+        return ans


 class SwooshLFunction(torch.autograd.Function):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 9b0100c14..b2d99c83f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1868,7 +1868,9 @@ class ConvNeXt(nn.Module):
             mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
         else:
             mask = None
-        return caching_eval(self.forward_internal, x, mask)
+        # turns out this caching idea does not work with --world-size > 1
+        #return caching_eval(self.forward_internal, x, mask)
+        return self.forward_internal(x, mask)

     def forward_internal(self,
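
For context on the two hunks: the scaling.py change makes Dropout3.forward() return the masked output (ans) instead of the boolean mask, and the zipformer.py change calls forward_internal() directly because the result-caching wrapper does not work with --world-size > 1. The snippet below is a minimal sketch of the shared-dimension dropout pattern involved, assuming plain inverted-dropout scaling; it is a simplified stand-in for MulForDropout3, not the icefall implementation.

import torch

def shared_dim_dropout(x: torch.Tensor, p: float, shared_dim: int,
                       training: bool = True) -> torch.Tensor:
    # Sketch only: one dropout decision is shared along `shared_dim`,
    # and surviving values are rescaled by 1 / (1 - p).
    if not training or p == 0.0:
        return x
    scale = 1.0 / (1.0 - p)
    rand_shape = list(x.shape)
    rand_shape[shared_dim] = 1            # mask is broadcast along this dim
    mask = torch.rand(*rand_shape, device=x.device) > p
    ans = x * mask * scale
    return ans                            # return the masked tensor, not `mask`

x = torch.randn(2, 3, 4)
y = shared_dim_dropout(x, p=0.5, shared_dim=1)
assert y.shape == x.shape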