From b736bb484089bfa4fe2b7ebe35d234b23e700fdb Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 12 Oct 2022 19:34:48 +0800
Subject: [PATCH] Cosmetic improvements

---
 egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 2 +-
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 185f7c98d..ca5412d32 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -487,7 +487,7 @@ class ConformerEncoder(nn.Module):
             if len(ans) == num_to_drop:
                 break

         if shared_rng.random() < 0.005 or __name__ == "__main__":
-            logging.info(f"warmup_begin={warmup_begin}, warmup_end={warmup_end}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
+            logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
         return ans

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 5d63137ff..e74acb7fe 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -444,8 +444,8 @@ class MaxEig(torch.nn.Module):

         with torch.cuda.amp.autocast(enabled=False):
             eps = 1.0e-20
-            assert x.dtype != torch.float16
             orig_x = x
+            x = x.to(torch.float32)
             with torch.no_grad():
                 x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
                 x = x - x.mean(dim=0)
@@ -461,7 +461,7 @@
                 # ensure new direction is nonzero even if x == 0, by including `direction`.
                 self._set_direction(0.1 * self.max_eig_direction + new_direction)
-            if random.random() < 0.0005 or __name__ == "__main__":
+            if random.random() < 0.01 or __name__ == "__main__":
                 logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}")
             if variance_proportion >= self.max_var_per_eig:
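
A note on the scaling.py change, with a sketch: under torch.cuda.amp.autocast, MaxEig's input can arrive as float16. The removed assertion simply rejected that case, while the new explicit x = x.to(torch.float32) accepts it, since entering autocast(enabled=False) does not cast tensors that already exist. The snippet below is a minimal illustration of that pattern, together with the patch's probabilistic, float-formatted logging style; it is not the actual MaxEig implementation, and the helper name top_eig_variance_proportion, the power-iteration details, and num_iters are illustrative assumptions.

import logging
import random

import torch


def top_eig_variance_proportion(x: torch.Tensor, num_iters: int = 10) -> torch.Tensor:
    # Estimate the fraction of total variance explained by the largest
    # eigenvalue of cov(x), for x of shape (num_frames, num_channels).
    eps = 1.0e-20
    with torch.cuda.amp.autocast(enabled=False):
        orig_x = x
        # The assertion the patch removes would have rejected float16 input
        # here; casting instead lets AMP-produced activations through.
        x = x.to(torch.float32)
        with torch.no_grad():
            x = x - x.mean(dim=0)
            cov = torch.matmul(x.t(), x)  # (num_channels, num_channels)
            v = torch.randn(cov.shape[0])
            for _ in range(num_iters):  # power iteration for the top eigenvector
                v = torch.matmul(cov, v)
                v = v / (v.norm() + eps)
            top_eig = torch.dot(v, torch.matmul(cov, v))  # Rayleigh quotient
            variance_proportion = top_eig / (cov.diag().sum() + eps)
    # Mirror the patch's logging style: log only rarely (0.01 is the new rate)
    # and give floats an explicit format spec so the message stays compact.
    if random.random() < 0.01 or __name__ == "__main__":
        logging.info(
            f"variance_proportion = {variance_proportion.item():.3f}, "
            f"shape={tuple(orig_x.shape)}"
        )
    return variance_proportion


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print(top_eig_variance_proportion(torch.randn(100, 8)).item())

Note that the patch keeps orig_x = x before the cast, preserving a handle to the original-dtype tensor; the surrounding code still uses it afterwards, e.g. for orig_x.shape in the log line. The conformer.py hunk is the same formatting idea applied to logging: adding :.1f to warmup_begin and warmup_end so they print with one decimal place like warmup_count already did.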