mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Cosmetic improvements
This commit is contained in:
parent
12323025d7
commit
b736bb4840
@ -487,7 +487,7 @@ class ConformerEncoder(nn.Module):
|
|||||||
if len(ans) == num_to_drop:
|
if len(ans) == num_to_drop:
|
||||||
break
|
break
|
||||||
if shared_rng.random() < 0.005 or __name__ == "__main__":
|
if shared_rng.random() < 0.005 or __name__ == "__main__":
|
||||||
logging.info(f"warmup_begin={warmup_begin}, warmup_end={warmup_end}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -444,8 +444,8 @@ class MaxEig(torch.nn.Module):
|
|||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
eps = 1.0e-20
|
eps = 1.0e-20
|
||||||
assert x.dtype != torch.float16
|
|
||||||
orig_x = x
|
orig_x = x
|
||||||
|
x = x.to(torch.float32)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
|
x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
|
||||||
x = x - x.mean(dim=0)
|
x = x - x.mean(dim=0)
|
||||||
@ -461,7 +461,7 @@ class MaxEig(torch.nn.Module):
|
|||||||
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
||||||
self._set_direction(0.1 * self.max_eig_direction + new_direction)
|
self._set_direction(0.1 * self.max_eig_direction + new_direction)
|
||||||
|
|
||||||
if random.random() < 0.0005 or __name__ == "__main__":
|
if random.random() < 0.01 or __name__ == "__main__":
|
||||||
logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}")
|
logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}")
|
||||||
|
|
||||||
if variance_proportion >= self.max_var_per_eig:
|
if variance_proportion >= self.max_var_per_eig:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user