From e4a3b2da7da1aaf858c92f587f92dd4c62d6d5b1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 9 Nov 2022 17:40:09 +0800
Subject: [PATCH] Mostly-cosmetic fixes found via mypy

---
 .../pruned_transducer_stateless7/scaling.py   | 34 ++++++++-----------
 .../pruned_transducer_stateless7/zipformer.py |  4 +--
 2 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index c7803b871..b9322f013 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -167,7 +167,12 @@ class RandomClampFunction(torch.autograd.Function):
                 max: Optional[float],
                 prob: float,
                 reflect: float) -> Tensor:
-        x_clamped = torch.clamp(x, min=min, max=max)
+        kwargs = {}
+        if min is not None:
+            kwargs['min'] = min
+        if max is not None:
+            kwargs['max'] = max
+        x_clamped = torch.clamp(x, **kwargs)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
@@ -418,7 +423,7 @@ def ScaledLinear(*args,
 
 def ScaledConv1d(*args,
                  initial_scale: float = 1.0,
-                 **kwargs ) -> nn.Linear:
+                 **kwargs ) -> nn.Conv1d:
     """
     Behaves like a constructor of a modified version of nn.Conv1d
     that gives an easy way to set the default initial parameter scale.
@@ -497,6 +502,11 @@ class ActivationBalancer(torch.nn.Module):
             min_prob: float = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
+
+        # CAUTION: this code expects self.batch_count to be overwritten in the main training
+        # loop.
+        self.batch_count = 0
+
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
@@ -508,11 +518,7 @@ class ActivationBalancer(torch.nn.Module):
         self.sign_gain_factor = sign_gain_factor
         self.scale_gain_factor = scale_gain_factor
 
-        # count measures how many times the forward() function has been called.
-        # We occasionally sync this to a tensor called `count`, that exists to
-        # make sure it is synced to disk when we load and save the model.
-        self.cpu_count = 0
-        self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
+
@@ -520,19 +526,9 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting() or not x.requires_grad:
             return _no_op(x)
 
-        count = self.cpu_count
-        self.cpu_count += 1
-
-        if random.random() < 0.01:
-            # Occasionally sync self.cpu_count with self.count.
-            # count affects the decay of 'prob'. don't do this on every iter,
-            # because syncing with the GPU is slow.
-            self.cpu_count = max(self.cpu_count, self.count.item())
-            self.count.fill_(self.cpu_count)
-
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0)))
+        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
             sign_gain_factor = 0.5
@@ -890,7 +886,7 @@ class MaxEig(torch.nn.Module):
 
     def _find_direction_coeffs(self,
                                x: Tensor,
-                               prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+                               prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Figure out (an approximation to) the proportion of the variance of a set of feature
         vectors that can be attributed to the top eigen-direction.
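
For reference, the schedule introduced above starts at 0.5 and halves every
4000 batches until it reaches the min_prob floor, and the CAUTION comment
means that code outside the module must keep batch_count current. A minimal
sketch of that contract, assuming a hypothetical set_batch_count helper that
is not part of this patch:

import torch.nn as nn


def decayed_prob(batch_count: float, min_prob: float = 0.1) -> float:
    # Mirrors the patched formula: 0.5 at batch_count == 0, halving every
    # 4000 batches, then held at the min_prob floor (0.1 by default).
    return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))


def set_batch_count(model: nn.Module, batch_count: int) -> None:
    # Overwrite `batch_count` on every submodule that declares it; this is
    # the per-batch update the CAUTION comment in ActivationBalancer.__init__
    # expects the main training loop to perform.
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count


# decayed_prob(0) == 0.5, decayed_prob(4000) == 0.25, and
# decayed_prob(12000) == 0.1 (the raw value 0.0625 is floored at min_prob).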
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index afaf864f0..949af6c19 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -781,7 +781,7 @@ class DownsamplingZipformerEncoder(nn.Module):
                 feature_mask: Union[Tensor, float] = 1.0,
                 mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tuple[Tensor, Tensor]:
+                ) -> Tensor:
         r"""Downsample, go through encoder, upsample.
 
         Args:
@@ -914,7 +914,7 @@ class SimpleCombiner(torch.nn.Module):
     def __init__(self,
                  dim1: int,
                  dim2: int,
-                 min_weight: Tuple[float] = (0., 0.)):
+                 min_weight: Tuple[float, float] = (0., 0.)):
        super(SimpleCombiner, self).__init__()
        assert dim2 >= dim1
        initial_weight1 = 0.1
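
The SimpleCombiner change fixes a genuinely wrong annotation: in typing,
Tuple[float] means a tuple of exactly one float, so the two-element default
(0., 0.) does not match it. A hypothetical minimal reproduction, not taken
from zipformer.py:

from typing import Tuple


def old_style(min_weight: Tuple[float] = (0.0, 0.0)) -> None:
    # mypy rejects this: the default is a 2-tuple but the annotation
    # promises a 1-tuple.
    pass


def new_style(min_weight: Tuple[float, float] = (0.0, 0.0)) -> None:
    # Accepted: annotation and default are both 2-tuples of floats.
    pass

(Tuple[float, ...] would also type-check if the length were meant to vary.)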