Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Mostly-cosmetic fixes found via mypy
commit e4a3b2da7d
parent 308059edba
@@ -167,7 +167,12 @@ class RandomClampFunction(torch.autograd.Function):
                 max: Optional[float],
                 prob: float,
                 reflect: float) -> Tensor:
-        x_clamped = torch.clamp(x, min=min, max=max)
+        kwargs = {}
+        if min is not None:
+            kwargs['min'] = min
+        if max is not None:
+            kwargs['max'] = max
+        x_clamped = torch.clamp(x, **kwargs)
         mask = torch.rand_like(x) < prob
         ans = torch.where(mask, x_clamped, x)
         if x.requires_grad:
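
The hunk above replaces the direct torch.clamp(x, min=min, max=max) call with a kwargs dict so that a bound that is None is simply not passed, presumably to keep the call within torch.clamp's typed overloads for mypy. A minimal sketch of the same pattern in isolation (the function name is illustrative, not from the repo):

    from typing import Optional

    import torch
    from torch import Tensor

    def clamp_optional(x: Tensor,
                       min: Optional[float] = None,
                       max: Optional[float] = None) -> Tensor:
        # Forward only the bounds that were actually given, so the call
        # matches torch.clamp's typed overloads under mypy.
        kwargs = {}
        if min is not None:
            kwargs['min'] = min
        if max is not None:
            kwargs['max'] = max
        # With no bounds there is nothing to clamp; torch.clamp itself would
        # raise if both min and max were None.
        return torch.clamp(x, **kwargs) if kwargs else x
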
@@ -418,7 +423,7 @@ def ScaledLinear(*args,

 def ScaledConv1d(*args,
                  initial_scale: float = 1.0,
-                 **kwargs ) -> nn.Linear:
+                 **kwargs ) -> nn.Conv1d:
     """
     Behaves like a constructor of a modified version of nn.Conv1d
     that gives an easy way to set the default initial parameter scale.
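
The only change here is the return annotation: ScaledConv1d constructs an nn.Conv1d, so the copy-pasted "-> nn.Linear" annotation misled type checkers about what call sites receive. A reduced sketch of the constructor-function pattern, not the repo's actual implementation (the body below is an assumption for illustration only):

    import torch
    import torch.nn as nn

    def scaled_conv1d_sketch(*args,
                             initial_scale: float = 1.0,
                             **kwargs) -> nn.Conv1d:
        # Build a plain Conv1d and rescale its initial parameters; the
        # annotation must say nn.Conv1d so that Conv1d-specific attributes
        # (e.g. .kernel_size) type-check at call sites.
        conv = nn.Conv1d(*args, **kwargs)
        with torch.no_grad():
            conv.weight.mul_(initial_scale)
            if conv.bias is not None:
                conv.bias.mul_(initial_scale)
        return conv
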
@@ -497,6 +502,11 @@ class ActivationBalancer(torch.nn.Module):
                  min_prob: float = 0.1,
                  ):
         super(ActivationBalancer, self).__init__()
+
+        # CAUTION: this code expects self.batch_count to be overwritten in the main training
+        # loop.
+        self.batch_count = 0
+
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
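
The new self.batch_count is never incremented by the module itself; as the CAUTION comment says, the main training loop is expected to overwrite it. A minimal sketch of that hookup, assuming a helper along these lines (the helper name and params.batch_idx_train are illustrative, not necessarily how the repo's train.py does it):

    import torch.nn as nn

    def set_batch_count(model: nn.Module, batch_count: float) -> None:
        # Push the global batch index into every submodule that exposes a
        # batch_count attribute (e.g. ActivationBalancer), so schedules tied
        # to training progress survive checkpoint resumption.
        for module in model.modules():
            if hasattr(module, "batch_count"):
                module.batch_count = batch_count

    # Called once per batch inside the training loop, e.g.:
    #     set_batch_count(model, params.batch_idx_train)
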
@@ -508,11 +518,7 @@ class ActivationBalancer(torch.nn.Module):
         self.sign_gain_factor = sign_gain_factor
         self.scale_gain_factor = scale_gain_factor

-        # count measures how many times the forward() function has been called.
-        # We occasionally sync this to a tensor called `count`, that exists to
-        # make sure it is synced to disk when we load and save the model.
-        self.cpu_count = 0
-        self.register_buffer('count', torch.tensor(0, dtype=torch.int64))

@@ -520,19 +526,9 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting() or not x.requires_grad:
             return _no_op(x)

-        count = self.cpu_count
-        self.cpu_count += 1
-
-        if random.random() < 0.01:
-            # Occasionally sync self.cpu_count with self.count.
-            # count affects the decay of 'prob'. don't do this on every iter,
-            # because syncing with the GPU is slow.
-            self.cpu_count = max(self.cpu_count, self.count.item())
-            self.count.fill_(self.cpu_count)
-
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0)))
+        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))

         if random.random() < prob:
             sign_gain_factor = 0.5
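
With the per-call counter gone, the probability of doing balancing work is now driven by the externally maintained batch_count: it starts at 0.5, halves every 4000 batches, and bottoms out at min_prob. A small illustration of the schedule (values computed from the formula above):

    def balancer_prob(batch_count: float, min_prob: float = 0.1) -> float:
        # Starts at 0.5, halves every 4000 batches, floored at min_prob.
        return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))

    # batch_count:   0     4000   8000   12000
    # prob:          0.5   0.25   0.125  0.1   (floor reached)
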
@@ -890,7 +886,7 @@ class MaxEig(torch.nn.Module):

     def _find_direction_coeffs(self,
                                x: Tensor,
-                               prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+                               prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Figure out (an approximation to) the proportion of the variance of a set of
         feature vectors that can be attributed to the top eigen-direction.
@@ -781,7 +781,7 @@ class DownsamplingZipformerEncoder(nn.Module):
                 feature_mask: Union[Tensor, float] = 1.0,
                 mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tuple[Tensor, Tensor]:
+                ) -> Tensor:
         r"""Downsample, go through encoder, upsample.

         Args:
@@ -914,7 +914,7 @@ class SimpleCombiner(torch.nn.Module):
     def __init__(self,
                  dim1: int,
                  dim2: int,
-                 min_weight: Tuple[float] = (0., 0.)):
+                 min_weight: Tuple[float, float] = (0., 0.)):
         super(SimpleCombiner, self).__init__()
         assert dim2 >= dim1
         initial_weight1 = 0.1
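
In typing terms, Tuple[float] means a tuple with exactly one float element, so the default value (0., 0.) contradicted the old annotation; Tuple[float, float] describes the intended fixed-length pair. A tiny standalone illustration of the distinction (not repo code):

    from typing import Tuple

    def takes_pair(min_weight: Tuple[float, float] = (0.0, 0.0)) -> float:
        # Tuple[float, float] is a two-element tuple; annotating this
        # parameter as Tuple[float] would make mypy reject the (0.0, 0.0)
        # default, because that annotation means a one-element tuple.
        return min_weight[0] + min_weight[1]
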