Mirror of https://github.com/k2-fsa/icefall.git
Add 1d squeeze and excite (-like) module in Conv2dSubsampling
This commit is contained in:
parent dd3826104e
commit 534eca4bf3
@@ -29,6 +29,76 @@ from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding
 
 
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function.  You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop.  It is a piecewise linear function where you specify the (x, y) pairs
+    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, when we are not in training mode, or in
+    torch.jit scripting mode.
+    """
+    def __init__(self,
+                 *args,
+                 default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        assert len(args) >= 1
+        for (x, y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        print_prob = 0.0002
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+        batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
+        if batch_count <= self.schedule[0][0]:
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return float(ans)
+        elif batch_count >= self.schedule[-1][0]:
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return float(ans)
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return float(ans)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(
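A minimal usage sketch (editorial illustration, not part of this commit) of how the new ScheduledFloat evaluates; the batch counts below are arbitrary example values, and in real training the loop is what writes batch_count:

    # Hypothetical example of the piecewise-linear schedule defined above.
    sf = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
    sf.train()               # float() only follows the schedule in training mode
    sf.batch_count = 1000.0  # normally written by the training loop
    # interpolating between (0.0, 0.2) and (4000.0, 0.0):
    # 0.2 + (0.0 - 0.2) * (1000 - 0) / (4000 - 0) = 0.15
    assert abs(float(sf) - 0.15) < 1e-6
    sf.batch_count = 5000.0  # past the last x value: clamp to the last y
    assert float(sf) == 0.0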
@@ -417,14 +487,14 @@ class ActivationBalancer(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int,
-        min_positive: float = 0.05,
-        max_positive: float = 0.95,
-        max_factor: float = 0.04,
-        sign_gain_factor: float = 0.01,
-        scale_gain_factor: float = 0.02,
-        min_abs: float = 0.2,
-        max_abs: float = 100.0,
-        min_prob: float = 0.1,
+        min_positive: FloatLike = 0.05,
+        max_positive: FloatLike = 0.95,
+        max_factor: FloatLike = 0.04,
+        sign_gain_factor: FloatLike = 0.01,
+        scale_gain_factor: FloatLike = 0.02,
+        min_abs: FloatLike = 0.2,
+        max_abs: FloatLike = 100.0,
+        min_prob: FloatLike = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
         # CAUTION: this code expects self.batch_count to be overwritten in the main training
@@ -453,25 +523,26 @@ class ActivationBalancer(torch.nn.Module):
 
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels
             sign_gain_factor = 0.5
-            if self.min_positive != 0.0 or self.max_positive != 1.0:
+            if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
                 sign_factor = _compute_sign_factor(x, self.channel_dim,
-                                                   self.min_positive, self.max_positive,
-                                                   gain_factor=self.sign_gain_factor / prob,
-                                                   max_factor=self.max_factor)
+                                                   float(self.min_positive),
+                                                   float(self.max_positive),
+                                                   gain_factor=float(self.sign_gain_factor) / prob,
+                                                   max_factor=float(self.max_factor))
             else:
                 sign_factor = None
 
 
             scale_factor = _compute_scale_factor(x, self.channel_dim,
-                                                 min_abs=self.min_abs,
-                                                 max_abs=self.max_abs,
-                                                 gain_factor=self.scale_gain_factor / prob,
-                                                 max_factor=self.max_factor)
+                                                 min_abs=float(self.min_abs),
+                                                 max_abs=float(self.max_abs),
+                                                 gain_factor=float(self.scale_gain_factor) / prob,
+                                                 max_factor=float(self.max_factor))
             return ActivationBalancerFunction.apply(
                 x, scale_factor, sign_factor, self.channel_dim,
             )
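For reference, a small editorial sketch (not from this commit) of how the work probability above decays with batch_count under the default min_prob of 0.1:

    # prob = max(min_prob, 0.5 ** (1 + batch_count / 4000)), illustrative values:
    min_prob = 0.1
    for batch_count in (0, 4000, 8000, 12000):
        prob = max(min_prob, 0.5 ** (1 + (batch_count / 4000.0)))
        print(batch_count, prob)
    # 0     -> 0.5
    # 4000  -> 0.25
    # 8000  -> 0.125
    # 12000 -> 0.1   (0.5 ** 4 = 0.0625 is below the floor, so min_prob wins)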
@@ -519,74 +590,6 @@ def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
 
 
 
-class ScheduledFloat(torch.nn.Module):
-    """
-    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-    it does not have a working forward() function.  You are supposed to cast it to float, as
-    in, float(parent_module.whatever), and use it as something like a dropout prob.
-
-    It is a floating point value whose value changes depending on the batch count of the
-    training loop.  It is a piecewise linear function where you specifiy the (x,y) pairs
-    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
-    first x or after the last x, we just use the first or last y value.
-
-    Example:
-       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
-    `default` is used when self.batch_count is not set or in training or mode or in
-    torch.jit scripting mode.
-    """
-    def __init__(self,
-                 *args,
-                 default: float = 0.0):
-        super().__init__()
-        # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = None
-        self.name = None
-        self.default = default
-        assert len(args) >= 1
-        for (x,y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
-
-    def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
-
-    def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-        batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
-            return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
-        else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
-
-
-FloatLike = Union[float, ScheduledFloat]
-
-
 def _whitening_metric(x: Tensor,
                       num_groups: int):
     """
@@ -1612,6 +1612,50 @@ class ConvolutionModule(nn.Module):
         x = x.permute(2, 0, 1)  # (time, batch, channel)
         return x
 
 
+class SqueezeExcite1d(nn.Module):
+    def __init__(self,
+                 channels: int,
+                 bottleneck_channels: int):
+        super().__init__()
+        self.to_bottleneck_proj = nn.Conv1d(in_channels=channels,
+                                            out_channels=bottleneck_channels,
+                                            kernel_size=1)
+        self.bottleneck_activation = TanSwish()
+        self.from_bottleneck_proj = nn.Conv1d(in_channels=bottleneck_channels,
+                                              out_channels=channels,
+                                              kernel_size=1)
+        self.balancer = ActivationBalancer(
+            channels, channel_dim=1,
+            min_abs=0.05,
+            max_abs=ScheduledFloat((0.0, 0.2),
+                                   (4000.0, 2.0),
+                                   (10000.0, 10.0),
+                                   default=1.0),
+            max_factor=0.02,
+            min_prob=0.1,
+        )
+        self.activation = nn.Sigmoid()
+
+    def forward(self, x: Tensor):
+        """
+        x: a Tensor of shape (batch_size, channels, T).
+        Returns: something with the same shape as x.
+        """
+        # would replace this mean with cumsum for a causal model.
+        bottleneck = x.mean(dim=2, keepdim=True)
+
+        bottleneck = self.to_bottleneck_proj(bottleneck)
+        bottleneck = self.bottleneck_activation(bottleneck)
+        scale = self.from_bottleneck_proj(bottleneck)
+        scale = self.balancer(scale)
+        scale = self.activation(scale)
+        return x * scale
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
 
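An editorial sketch of the dataflow that SqueezeExcite1d implements, rewritten with plain PyTorch pieces so it is self-contained; the sizes are arbitrary and the tanh-based nonlinearity is only a stand-in for TanSwish, whose definition is not part of this diff:

    import torch
    import torch.nn as nn

    B, C, C_b, T = 4, 256, 64, 100             # illustrative sizes
    x = torch.randn(B, C, T)                   # (batch, channels, time)
    squeeze = x.mean(dim=2, keepdim=True)      # (B, C, 1): average over time
    down = nn.Conv1d(C, C_b, kernel_size=1)(squeeze)   # project to the bottleneck
    down = down * torch.tanh(down)             # stand-in for TanSwish
    up = nn.Conv1d(C_b, C, kernel_size=1)(down)        # project back to C channels
    scale = torch.sigmoid(up)                  # per-channel gates in (0, 1)
    y = x * scale                              # broadcast over time: (B, C, T)
    assert y.shape == x.shape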
@@ -1630,6 +1674,7 @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
+        bottleneck_channels: int = 64,
         dropout: float = 0.1,
     ) -> None:
         """
@@ -1643,6 +1688,8 @@ class Conv2dSubsampling(nn.Module):
           Number of channels in layer1
         layer1_channels:
           Number of channels in layer2
+        bottleneck_channels:
+          Bottleneck dimension for the 1d squeeze-excite module.
         """
         assert in_channels >= 7
         super().__init__()
@@ -1678,6 +1725,10 @@ class Conv2dSubsampling(nn.Module):
             DoubleSwish(),
         )
         out_height = (((in_channels - 1) // 2) - 1) // 2
+
+        self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels,
+                                              bottleneck_channels)
+
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
         self.dropout = nn.Dropout(dropout)
 
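For a sense of the sizes being wired together here (editorial note; in_channels=80 is just an assumed typical feature dimension, while 128 and 64 are the defaults from the signature above):

    in_channels, layer3_channels, bottleneck_channels = 80, 128, 64
    out_height = (((in_channels - 1) // 2) - 1) // 2    # ((79 // 2) - 1) // 2 = 19
    se_channels = out_height * layer3_channels          # 19 * 128 = 2432
    # so the module above would be SqueezeExcite1d(2432, 64):
    # squeeze 2432 channels to a 64-dim bottleneck, then back up to 2432 gates.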
@@ -1697,7 +1748,14 @@ class Conv2dSubsampling(nn.Module):
         x = self.conv(x)
         # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
+
+        x = x.transpose(1, 2).reshape(b, t, c * f)
+        # now x: (N, ((T-1)//2 - 1)//2, out_height * layer3_channels)
+        x = x.transpose(1, 2)
+        x = self.squeeze_excite(x)
+        x = x.transpose(1, 2)
+
+        x = self.out(x)
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         x = self.dropout(x)
         return x
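An editorial shape trace through the modified forward() (assuming idim = 80 input features as in the note above; T' stands for the subsampled frame count produced by self.conv):

    # x after self.conv:          (N, 128, T', 19)   (N, layer3_channels, T', out_height)
    # transpose(1, 2) + reshape:  (N, T', 2432)      2432 = out_height * layer3_channels
    # transpose(1, 2):            (N, 2432, T')      channels-first, as SqueezeExcite1d expects
    # self.squeeze_excite:        (N, 2432, T')      same shape, per-channel gating
    # transpose(1, 2):            (N, T', 2432)
    # self.out, self.dropout:     (N, T', out_channels)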