Mostly-cosmetic fixes found via mypy

Author: Daniel Povey 2022-11-09 17:40:09 +08:00
Parent: 308059edba
Commit: e4a3b2da7d
2 changed files with 17 additions and 21 deletions

File 1 of 2:

@@ -167,7 +167,12 @@ class RandomClampFunction(torch.autograd.Function):
                 max: Optional[float],
                 prob: float,
                 reflect: float) -> Tensor:
-        x_clamped = torch.clamp(x, min=min, max=max)
+        kwargs = {}
+        if min is not None:
+            kwargs['min'] = min
+        if max is not None:
+            kwargs['max'] = max
+        x_clamped = torch.clamp(x, **kwargs)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
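The kwargs-building pattern above presumably keeps mypy happy with the Optional bounds: torch.clamp accepts min=None/max=None at runtime, but only the non-None bounds are forwarded here. A minimal standalone sketch of the same pattern (the function name and the empty-kwargs guard are mine, not the commit's):

    from typing import Optional

    import torch
    from torch import Tensor

    def clamp_optional(x: Tensor,
                       min: Optional[float] = None,
                       max: Optional[float] = None) -> Tensor:
        # Forward only the bounds that are actually set, as in the diff above.
        kwargs = {}
        if min is not None:
            kwargs['min'] = min
        if max is not None:
            kwargs['max'] = max
        if not kwargs:
            return x  # torch.clamp errors if neither bound is given
        return torch.clamp(x, **kwargs)

    print(clamp_optional(torch.tensor([-2.0, 0.5, 3.0]), min=0.0))
    # prints: tensor([0.0000, 0.5000, 3.0000])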
@@ -418,7 +423,7 @@ def ScaledLinear(*args,
 def ScaledConv1d(*args,
                  initial_scale: float = 1.0,
-                 **kwargs ) -> nn.Linear:
+                 **kwargs ) -> nn.Conv1d:
     """
     Behaves like a constructor of a modified version of nn.Conv1d
     that gives an easy way to set the default initial parameter scale.
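The fix above corrects a return annotation evidently copy-pasted from ScaledLinear (nn.Linear -> nn.Conv1d). For context, a hedged sketch of the "constructor with an initial parameter scale" idea the docstring describes; the real ScaledConv1d may differ in details such as how the bias is initialized:

    import torch
    import torch.nn as nn

    def scaled_conv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
        conv = nn.Conv1d(*args, **kwargs)
        with torch.no_grad():
            conv.weight[:] *= initial_scale  # shrink/grow the default init
            if conv.bias is not None:
                conv.bias[:] *= initial_scale
        return conv

    conv = scaled_conv1d(80, 256, kernel_size=3, initial_scale=0.25)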
@@ -497,6 +502,11 @@ class ActivationBalancer(torch.nn.Module):
                  min_prob: float = 0.1,
                  ):
         super(ActivationBalancer, self).__init__()
+        # CAUTION: this code expects self.batch_count to be overwritten in the main training
+        # loop.
+        self.batch_count = 0
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
@ -508,11 +518,7 @@ class ActivationBalancer(torch.nn.Module):
self.sign_gain_factor = sign_gain_factor self.sign_gain_factor = sign_gain_factor
self.scale_gain_factor = scale_gain_factor self.scale_gain_factor = scale_gain_factor
# count measures how many times the forward() function has been called.
# We occasionally sync this to a tensor called `count`, that exists to
# make sure it is synced to disk when we load and save the model.
self.cpu_count = 0
self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
@@ -520,19 +526,9 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting() or not x.requires_grad:
             return _no_op(x)
-        count = self.cpu_count
-        self.cpu_count += 1
-        if random.random() < 0.01:
-            # Occasionally sync self.cpu_count with self.count.
-            # count affects the decay of 'prob'. don't do this on every iter,
-            # because syncing with the GPU is slow.
-            self.cpu_count = max(self.cpu_count, self.count.item())
-            self.count.fill_(self.cpu_count)
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0)))
+        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
         if random.random() < prob:
             sign_gain_factor = 0.5
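Net effect of the two hunks above: the self-maintained cpu_count and the `count` buffer are gone, and the decay of `prob` is driven by a plain batch_count attribute that the training loop must overwrite (per the CAUTION comment added in __init__). A small sketch of the schedule, plus a hypothetical hookup; the loop code is assumed, not part of this commit:

    def work_prob(batch_count: float, min_prob: float = 0.1) -> float:
        # As in the diff: starts at 0.5, halves every 4000 batches, floors at
        # min_prob (the floor is hit around batch 9300 with the defaults).
        return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))

    for bc in (0, 4000, 8000, 12000):
        print(bc, work_prob(bc))  # 0.5, 0.25, 0.125, 0.1

    # Hypothetical hookup in the main training loop (only the attribute name
    # 'batch_count' comes from this commit; the rest is assumed):
    # for module in model.modules():
    #     if hasattr(module, 'batch_count'):
    #         module.batch_count = batch_idx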
@@ -890,7 +886,7 @@ class MaxEig(torch.nn.Module):
     def _find_direction_coeffs(self,
                                x: Tensor,
-                               prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+                               prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Figure out (an approximation to) the proportion of the variance of a set of
         feature vectors that can be attributed to the top eigen-direction.

File 2 of 2:

@@ -781,7 +781,7 @@ class DownsamplingZipformerEncoder(nn.Module):
                 feature_mask: Union[Tensor, float] = 1.0,
                 mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tuple[Tensor, Tensor]:
+                ) -> Tensor:
         r"""Downsample, go through encoder, upsample.
         Args:
@@ -914,7 +914,7 @@ class SimpleCombiner(torch.nn.Module):
     def __init__(self,
                  dim1: int,
                  dim2: int,
-                 min_weight: Tuple[float] = (0., 0.)):
+                 min_weight: Tuple[float, float] = (0., 0.)):
         super(SimpleCombiner, self).__init__()
         assert dim2 >= dim1
         initial_weight1 = 0.1
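The SimpleCombiner hunk is a pure typing fix: Tuple[float] denotes a one-element tuple, so the two-element default (0., 0.) contradicted the old annotation, which is exactly the kind of mismatch mypy reports. For illustration (not from the commit):

    from typing import Tuple

    bad: Tuple[float] = (0., 0.)           # mypy: incompatible types (2-tuple vs 1-tuple)
    good: Tuple[float, float] = (0., 0.)   # matches the corrected annotation
    any_len: Tuple[float, ...] = (0., 0.)  # variable-length alternative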