From 534eca4bf383117ad62b694e0fce9a887c40248c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 24 Nov 2022 16:18:40 +0800
Subject: [PATCH 1/2] Add 1d squeeze and excite (-like) module in Conv2dSubsampling

---
 .../pruned_transducer_stateless7/scaling.py   | 173 +++++++++---------
 .../pruned_transducer_stateless7/zipformer.py |  60 +++++-
 2 files changed, 147 insertions(+), 86 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 93d6d631b..fb197a061 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -29,6 +29,76 @@ from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding
+
+
+class ScheduledFloat(torch.nn.Module):
+    """
+    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+    it does not have a working forward() function.  You are supposed to cast it to float, as
+    in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+    It is a floating point value whose value changes depending on the batch count of the
+    training loop.  It is a piecewise linear function where you specify the (x,y) pairs
+    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
+    first x or after the last x, we just use the first or last y value.
+
+    Example:
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, when we are not in training mode,
+    or when we are in torch.jit scripting mode.
+    """
+    def __init__(self,
+                 *args,
+                 default: float = 0.0):
+        super().__init__()
+        # self.batch_count and self.name will be written to in the training loop.
+        self.batch_count = None
+        self.name = None
+        self.default = default
+        assert len(args) >= 1
+        for (x,y) in args:
+            assert x >= 0
+        for i in range(len(args) - 1):
+            assert args[i + 1] > args[i], args
+        self.schedule = args
+
+    def extra_repr(self) -> str:
+        return 'batch_count={}, schedule={}'.format(self.batch_count,
+                                                    self.schedule)
+
+    def __float__(self):
+        print_prob = 0.0002
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
+        batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
+        if batch_count <= self.schedule[0][0]:
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return float(ans)
+        elif batch_count >= self.schedule[-1][0]:
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return float(ans)
+        else:
+            cur_x, cur_y = self.schedule[0]
+            for i in range(1, len(self.schedule)):
+                next_x, next_y = self.schedule[i]
+                if batch_count >= cur_x and batch_count <= next_x:
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return float(ans)
+                cur_x, cur_y = next_x, next_y
+            assert False
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
+
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -417,14 +487,14 @@ class ActivationBalancer(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int,
-        min_positive: float = 0.05,
-        max_positive: float = 0.95,
-        max_factor: float = 0.04,
-        sign_gain_factor: float = 0.01,
-        scale_gain_factor: float = 0.02,
-        min_abs: float = 0.2,
-        max_abs: float = 100.0,
-        min_prob: float = 0.1,
+        min_positive: FloatLike = 0.05,
+        max_positive: FloatLike = 0.95,
+        max_factor: FloatLike = 0.04,
+        sign_gain_factor: FloatLike = 0.01,
+        scale_gain_factor: FloatLike = 0.02,
+        min_abs: FloatLike = 0.2,
+        max_abs: FloatLike = 100.0,
+        min_prob: FloatLike = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
         # CAUTION: this code expects self.batch_count to be overwritten in the main training
@@ -453,25 +523,26 @@ class ActivationBalancer(torch.nn.Module):
         # the prob of doing some work exponentially decreases from 0.5 till it hits
         # a floor at min_prob (==0.1, by default)
-        prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels
             sign_gain_factor = 0.5
-            if self.min_positive != 0.0 or self.max_positive != 1.0:
+            if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
                 sign_factor = _compute_sign_factor(x, self.channel_dim,
-                                                   self.min_positive, self.max_positive,
-                                                   gain_factor=self.sign_gain_factor / prob,
-                                                   max_factor=self.max_factor)
+                                                   float(self.min_positive),
+                                                   float(self.max_positive),
+                                                   gain_factor=float(self.sign_gain_factor) / prob,
+                                                   max_factor=float(self.max_factor))
             else:
                 sign_factor = None
 
             scale_factor = _compute_scale_factor(x, self.channel_dim,
-                                                 min_abs=self.min_abs,
-                                                 max_abs=self.max_abs,
-                                                 gain_factor=self.scale_gain_factor / prob,
-                                                 max_factor=self.max_factor)
+                                                 min_abs=float(self.min_abs),
+                                                 max_abs=float(self.max_abs),
+                                                 gain_factor=float(self.scale_gain_factor) / prob,
+                                                 max_factor=float(self.max_factor))
             return ActivationBalancerFunction.apply(
                 x, scale_factor, sign_factor, self.channel_dim,
             )
@@ -519,74 +590,6 @@ def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
-
-class ScheduledFloat(torch.nn.Module):
-    """
-    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
-    it does not have a working forward() function.  You are supposed to cast it to float, as
-    in, float(parent_module.whatever), and use it as something like a dropout prob.
-
-    It is a floating point value whose value changes depending on the batch count of the
-    training loop.  It is a piecewise linear function where you specifiy the (x,y) pairs
-    in sorted order on x; x corresponds to the batch index.  For batch-index values before the
-    first x or after the last x, we just use the first or last y value.
-
-    Example:
-       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
-    `default` is used when self.batch_count is not set or in training or mode or in
-     torch.jit scripting mode.
-    """
-    def __init__(self,
-                 *args,
-                 default: float = 0.0):
-        super().__init__()
-        # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = None
-        self.name = None
-        self.default = default
-        assert len(args) >= 1
-        for (x,y) in args:
-            assert x >= 0
-        for i in range(len(args) - 1):
-            assert args[i + 1] > args[i], args
-        self.schedule = args
-
-    def extra_repr(self) -> str:
-        return 'batch_count={}, schedule={}'.format(self.batch_count,
-                                                    self.schedule)
-
-    def __float__(self):
-        print_prob = 0.0002
-        def maybe_print(ans):
-            if random.random() < print_prob:
-                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
-        batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
-            return float(self.default)
-        if batch_count <= self.schedule[0][0]:
-            ans = self.schedule[0][1]
-            maybe_print(ans)
-            return float(ans)
-        elif batch_count >= self.schedule[-1][0]:
-            ans = self.schedule[-1][1]
-            maybe_print(ans)
-            return float(ans)
-        else:
-            cur_x, cur_y = self.schedule[0]
-            for i in range(1, len(self.schedule)):
-                next_x, next_y = self.schedule[i]
-                if batch_count >= cur_x and batch_count <= next_x:
-                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
-                    maybe_print(ans)
-                    return float(ans)
-                cur_x, cur_y = next_x, next_y
-            assert False
-
-
-FloatLike = Union[float, ScheduledFloat]
-
-
 def _whitening_metric(x: Tensor, num_groups: int):
     """
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 44e024bce..8d41f527e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1612,6 +1612,50 @@ class ConvolutionModule(nn.Module):
         x = x.permute(2, 0, 1)  # (time, batch, channel)
         return x
+
+class SqueezeExcite1d(nn.Module):
+    def __init__(self,
+                 channels: int,
+                 bottleneck_channels: int):
+        super().__init__()
+        self.to_bottleneck_proj = nn.Conv1d(in_channels=channels,
+                                            out_channels=bottleneck_channels,
+                                            kernel_size=1)
+        self.bottleneck_activation = TanSwish()
+        self.from_bottleneck_proj = nn.Conv1d(in_channels=bottleneck_channels,
+                                              out_channels=channels,
+                                              kernel_size=1)
+        self.balancer = ActivationBalancer(
+            channels, channel_dim=1,
+            min_abs=0.05,
+            max_abs=ScheduledFloat((0.0, 0.2),
+                                   (4000.0, 2.0),
+                                   (10000.0, 10.0),
+                                   default=1.0),
+            max_factor=0.02,
+            min_prob=0.1,
+        )
+        self.activation = nn.Sigmoid()
+
+
+
+    def forward(self, x: Tensor):
+        """
+        x: a Tensor of shape (batch_size, channels, T).
+        Returns: a Tensor with the same shape as x.
+ """ + # would replace this mean with cumsum for a causal model. + bottleneck = x.mean(dim=2, keepdim=True) + + bottleneck = self.to_bottleneck_proj(bottleneck) + bottleneck = self.bottleneck_activation(bottleneck) + bottleneck = self.bottleneck_activation(bottleneck) + scale = self.from_bottleneck_proj(bottleneck) + scale = self.balancer(scale) + scale = self.activation(scale) + return x * scale + + class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/2 length). @@ -1630,6 +1674,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + bottleneck_channels: int = 64, dropout: float = 0.1, ) -> None: """ @@ -1643,6 +1688,8 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite """ assert in_channels >= 7 super().__init__() @@ -1678,6 +1725,10 @@ class Conv2dSubsampling(nn.Module): DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 + + self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels, + bottleneck_channels) + self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) @@ -1697,7 +1748,14 @@ class Conv2dSubsampling(nn.Module): x = self.conv(x) # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels)) + x = x.transpose(1, 2) + x = self.squeeze_excite(x) + x = x.transpose(1, 2) + + x = self.out(x) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) x = self.dropout(x) return x From 0614f654288ae8a1101eadb44f0587eab56a74a3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Nov 2022 17:20:28 +0800 Subject: [PATCH 2/2] Bug fix, remove 2nd activation in a row --- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 8d41f527e..110fb981a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1649,7 +1649,6 @@ class SqueezeExcite1d(nn.Module): bottleneck = self.to_bottleneck_proj(bottleneck) bottleneck = self.bottleneck_activation(bottleneck) - bottleneck = self.bottleneck_activation(bottleneck) scale = self.from_bottleneck_proj(bottleneck) scale = self.balancer(scale) scale = self.activation(scale)