Add 1d squeeze and excite (-like) module in Conv2dSubsampling

Daniel Povey 2022-11-24 16:18:40 +08:00
parent dd3826104e
commit 534eca4bf3
2 changed files with 147 additions and 86 deletions


@@ -29,6 +29,76 @@ from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding
class ScheduledFloat(torch.nn.Module):
"""
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
it does not have a working forward() function. You are supposed to cast it to float, as
in, float(parent_module.whatever), and use it as something like a dropout prob.
It is a floating point value whose value changes depending on the batch count of the
training loop. It is a piecewise linear function where you specify the (x,y) pairs
in sorted order on x; x corresponds to the batch index. For batch-index values before the
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
`default` is used when self.batch_count is not set, or we are not in training mode, or we are in
torch.jit scripting mode.
"""
def __init__(self,
*args,
default: float = 0.0):
super().__init__()
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = None
self.name = None
self.default = default
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
for i in range(len(args) - 1):
assert args[i + 1] > args[i], args
self.schedule = args
def extra_repr(self) -> str:
return 'batch_count={}, schedule={}'.format(self.batch_count,
self.schedule)
def __float__(self):
print_prob = 0.0002
def maybe_print(ans):
if random.random() < print_prob:
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
return float(self.default)
if batch_count <= self.schedule[0][0]:
ans = self.schedule[0][1]
maybe_print(ans)
return float(ans)
elif batch_count >= self.schedule[-1][0]:
ans = self.schedule[-1][1]
maybe_print(ans)
return float(ans)
else:
cur_x, cur_y = self.schedule[0]
for i in range(1, len(self.schedule)):
next_x, next_y = self.schedule[i]
if batch_count >= cur_x and batch_count <= next_x:
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
maybe_print(ans)
return float(ans)
cur_x, cur_y = next_x, next_y
assert False
FloatLike = Union[float, ScheduledFloat]
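As a quick illustration (not part of the commit), this is how a ScheduledFloat behaves once the training loop has written batch_count, following the piecewise-linear rule described in the docstring:
# Illustrative usage sketch, not part of the commit; relies only on the class above.
sf = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
sf.batch_count = 2000   # normally written by the training loop
print(float(sf))        # ~0.1: halfway along the (0.0, 0.2) -> (4000.0, 0.0) segment
sf.batch_count = 8000
print(float(sf))        # 0.0: past the last x, so the last y value is used
sf.eval()
print(float(sf))        # 0.0: not in training mode, so `default` is returned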
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
@@ -417,14 +487,14 @@ class ActivationBalancer(torch.nn.Module):
self,
num_channels: int,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.04,
sign_gain_factor: float = 0.01,
scale_gain_factor: float = 0.02,
min_abs: float = 0.2,
max_abs: float = 100.0,
min_prob: float = 0.1,
min_positive: FloatLike = 0.05,
max_positive: FloatLike = 0.95,
max_factor: FloatLike = 0.04,
sign_gain_factor: FloatLike = 0.01,
scale_gain_factor: FloatLike = 0.02,
min_abs: FloatLike = 0.2,
max_abs: FloatLike = 100.0,
min_prob: FloatLike = 0.1,
):
super(ActivationBalancer, self).__init__()
# CAUTION: this code expects self.batch_count to be overwritten in the main training
@@ -453,25 +523,26 @@ class ActivationBalancer(torch.nn.Module):
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
if random.random() < prob:
assert x.shape[self.channel_dim] == self.num_channels
sign_gain_factor = 0.5
if self.min_positive != 0.0 or self.max_positive != 1.0:
if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
sign_factor = _compute_sign_factor(x, self.channel_dim,
self.min_positive, self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor)
float(self.min_positive),
float(self.max_positive),
gain_factor=float(self.sign_gain_factor) / prob,
max_factor=float(self.max_factor))
else:
sign_factor = None
scale_factor = _compute_scale_factor(x, self.channel_dim,
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor)
min_abs=float(self.min_abs),
max_abs=float(self.max_abs),
gain_factor=float(self.scale_gain_factor) / prob,
max_factor=float(self.max_factor))
return ActivationBalancerFunction.apply(
x, scale_factor, sign_factor, self.channel_dim,
)
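The float() casts above are the point of the FloatLike change: the balancer stores whatever it was given (a plain constant or a ScheduledFloat) and resolves it at each forward call, so a scheduled value automatically tracks batch_count. A minimal sketch of the same pattern with a hypothetical module (not from the commit):
# Hypothetical example of the FloatLike pattern, not from the commit:
# keep the FloatLike as-is and cast with float() only at the point of use.
class ScheduledDropout(torch.nn.Module):
    def __init__(self, p: FloatLike):
        super().__init__()
        self.p = p  # either a plain float or a ScheduledFloat

    def forward(self, x: Tensor) -> Tensor:
        # float(self.p) re-evaluates the schedule on every call.
        return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)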
@@ -519,74 +590,6 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
class ScheduledFloat(torch.nn.Module):
"""
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
it does not have a working forward() function. You are supposed to cast it to float, as
in, float(parent_module.whatever), and use it as something like a dropout prob.
It is a floating point value whose value changes depending on the batch count of the
training loop. It is a piecewise linear function where you specify the (x,y) pairs
in sorted order on x; x corresponds to the batch index. For batch-index values before the
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
`default` is used when self.batch_count is not set, or we are not in training mode, or we are in
torch.jit scripting mode.
"""
def __init__(self,
*args,
default: float = 0.0):
super().__init__()
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = None
self.name = None
self.default = default
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
for i in range(len(args) - 1):
assert args[i + 1] > args[i], args
self.schedule = args
def extra_repr(self) -> str:
return 'batch_count={}, schedule={}'.format(self.batch_count,
self.schedule)
def __float__(self):
print_prob = 0.0002
def maybe_print(ans):
if random.random() < print_prob:
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
return float(self.default)
if batch_count <= self.schedule[0][0]:
ans = self.schedule[0][1]
maybe_print(ans)
return float(ans)
elif batch_count >= self.schedule[-1][0]:
ans = self.schedule[-1][1]
maybe_print(ans)
return float(ans)
else:
cur_x, cur_y = self.schedule[0]
for i in range(1, len(self.schedule)):
next_x, next_y = self.schedule[i]
if batch_count >= cur_x and batch_count <= next_x:
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
maybe_print(ans)
return float(ans)
cur_x, cur_y = next_x, next_y
assert False
FloatLike = Union[float, ScheduledFloat]
def _whitening_metric(x: Tensor,
num_groups: int):
"""


@@ -1612,6 +1612,50 @@ class ConvolutionModule(nn.Module):
x = x.permute(2, 0, 1) # (time, batch, channel)
return x
class SqueezeExcite1d(nn.Module):
def __init__(self,
channels: int,
bottleneck_channels: int):
super().__init__()
self.to_bottleneck_proj = nn.Conv1d(in_channels=channels,
out_channels=bottleneck_channels,
kernel_size=1)
self.bottleneck_activation = TanSwish()
self.from_bottleneck_proj = nn.Conv1d(in_channels=bottleneck_channels,
out_channels=channels,
kernel_size=1)
self.balancer = ActivationBalancer(
channels, channel_dim=1,
min_abs=0.05,
max_abs=ScheduledFloat((0.0, 0.2),
(4000.0, 2.0),
(10000.0, 10.0),
default=1.0),
max_factor=0.02,
min_prob=0.1,
)
self.activation = nn.Sigmoid()
def forward(self, x: Tensor):
"""
x: a Tensor of shape (batch_size, channels, T).
Returns: something with the same shape as x.
"""
# would replace this mean with cumsum for a causal model.
bottleneck = x.mean(dim=2, keepdim=True)
bottleneck = self.to_bottleneck_proj(bottleneck)
bottleneck = self.bottleneck_activation(bottleneck)
scale = self.from_bottleneck_proj(bottleneck)
scale = self.balancer(scale)
scale = self.activation(scale)
return x * scale
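For orientation (not part of the commit), here is a stripped-down version of the same squeeze-and-excite idea using only stock torch modules; the commit's SqueezeExcite1d additionally uses a TanSwish bottleneck activation and an ActivationBalancer with a scheduled max_abs before the sigmoid gate:
# Simplified sketch of 1d squeeze-and-excite; illustrative only.
import torch
import torch.nn as nn

class SimpleSqueezeExcite1d(nn.Module):
    """Per-channel gates computed from a global average over time."""
    def __init__(self, channels: int, bottleneck_channels: int):
        super().__init__()
        self.down = nn.Conv1d(channels, bottleneck_channels, kernel_size=1)
        self.up = nn.Conv1d(bottleneck_channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time)
        s = x.mean(dim=2, keepdim=True)   # squeeze: (batch, channels, 1)
        s = torch.relu(self.down(s))      # bottleneck: (batch, bottleneck_channels, 1)
        s = torch.sigmoid(self.up(s))     # excite: one gate per channel, (batch, channels, 1)
        return x * s                      # gate broadcast over the time axis

y = SimpleSqueezeExcite1d(channels=256, bottleneck_channels=64)(torch.randn(4, 256, 100))
assert y.shape == (4, 256, 100)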
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
@@ -1630,6 +1674,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
bottleneck_channels: int = 64,
dropout: float = 0.1,
) -> None:
"""
@@ -1643,6 +1688,8 @@ class Conv2dSubsampling(nn.Module):
Number of channels in layer1
layer2_channels:
Number of channels in layer2
bottleneck_channels:
bottleneck dimension for 1d squeeze-excite
"""
assert in_channels >= 7
super().__init__()
@@ -1678,6 +1725,10 @@ class Conv2dSubsampling(nn.Module):
DoubleSwish(),
)
out_height = (((in_channels - 1) // 2) - 1) // 2
self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels,
bottleneck_channels)
self.out = ScaledLinear(out_height * layer3_channels, out_channels)
self.dropout = nn.Dropout(dropout)
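For concreteness (the input dimension of 80 is an assumption, not stated in the commit): with in_channels = 80 mel bins, out_height = ((80 - 1) // 2 - 1) // 2 = 19, so with layer3_channels = 128 the squeeze-excite module and the output projection both operate on 19 * 128 = 2432 channels.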
@@ -1697,7 +1748,14 @@ class Conv2dSubsampling(nn.Module):
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, ((T-1)//2 - 1)//2, out_height * layer3_channels)
x = x.transpose(1, 2)
x = self.squeeze_excite(x)
x = x.transpose(1, 2)
x = self.out(x)
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
x = self.dropout(x)
return x
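To trace the new data flow through forward() (an illustrative sketch; the concrete values assume in_channels=80, layer3_channels=128, out_channels=512, which are not stated in the commit):
# Illustrative shape trace through the modified forward(), not part of the commit:
#   after self.conv           : (N, 128, T', 19)   19 = ((80-1)//2 - 1)//2 frequency bins
#   transpose(1, 2) + reshape : (N, T', 2432)      2432 = 128 * 19
#   transpose(1, 2)           : (N, 2432, T')      channels-first, as SqueezeExcite1d expects
#   squeeze_excite            : (N, 2432, T')      per-channel gating, shape unchanged
#   transpose(1, 2)           : (N, T', 2432)
#   self.out (ScaledLinear)   : (N, T', 512)
#   dropout                   : (N, T', 512)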