From 3838b8431362b1b02ed6ebd7ae1ae24cfbf951bd Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 13 May 2022 22:23:39 +0800 Subject: [PATCH] add conv module --- .../emformer.py | 289 ++++++++++++++++++ .../test_emformer.py | 54 ++++ 2 files changed, 343 insertions(+) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 24ee8b0be..f3ee7b0f7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -153,9 +153,298 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(x), self.dropout(pos_emb) +class ConvolutionModule(nn.Module): + """ConvolutionModule. + + Modified from https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa + + Args: + chunk_length (int): + Length of each chunk. + right_context_length (int): + Length of right context. + channels (int): + The number of channels of conv layers. + kernel_size (int): + Kernerl size of conv layers. + bias (bool): + Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, + chunk_length: int, + right_context_length: int, + channels: int, + kernel_size: int, + bias: bool = True, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.chunk_length = chunk_length + self.right_context_length = right_context_length + self.channels = channels + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # After pointwise_conv1 we put x through a gated linear unit + # (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in + # the range 1 to 4, but sometimes, for some reason, for layer 0 the rms + # ends up being very large, between 50 and 100 for different channels. + # This will cause very peaky and sparse derivatives for the sigmoid + # gating function, which will tend to make the loss function not learn + # effectively. (for most layers the average absolute values are in the + # range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for + # different layers, which likely breaks down as 0.5 for the "linear" + # half and 0.2 to 0.3 for the part that goes into the sigmoid. + # The idea is that if we constrain the rms values to a reasonable range + # via a constraint of max_abs=10.0, it will be in a better position to + # start learning something, i.e. to latch onto the correct range. + self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + # make it causal by padding cached (kernel_size - 1) frames on the left + self.cache_size = kernel_size - 1 + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25, + ) + + def _split_right_context( + self, + pad_utterance: torch.Tensor, + right_context: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + pad_utterance: + Its shape is (cache_size + U, B, D). + right_context: + Its shape is (R, B, D). + + Returns: + Right context segments padding with corresponding context. + Its shape is (num_segs * B, D, cache_size + right_context_length). + """ + U_, B, D = pad_utterance.size() + R = right_context.size(0) + assert self.right_context_length != 0 + assert R % self.right_context_length == 0 + num_chunks = R // self.right_context_length + right_context = right_context.reshape( + num_chunks, self.right_context_length, B, D + ) + right_context = right_context.permute(0, 2, 1, 3).reshape( + num_chunks * B, self.right_context_length, D + ) + padding = [] + for idx in range(num_chunks): + end_idx = min(U_, self.cache_size + (idx + 1) * self.chunk_length) + start_idx = end_idx - self.cache_size + padding.append(pad_utterance[start_idx:end_idx]) + padding = torch.cat(padding, dim=1).permute(1, 0, 2) + # (num_segs * B, cache_size, D) + pad_right_context = torch.cat([padding, right_context], dim=1) + # (num_segs * B, cache_size + right_context_length, D) + return pad_right_context.permute(0, 2, 1) + + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: + """ + Args: + right_context: + Right context segments. + It shape is (num_segs * B, D, right_context_length). + B: + Batch size. + + Returns: + A tensor of shape (B, D, R), where + R = num_segs * right_context_length. + """ + right_context = right_context.reshape( + -1, B, self.channels, self.right_context_length + ) + right_context = right_context.permute(1, 2, 0, 3) + right_context = right_context.reshape(B, self.channels, -1) + return right_context + + def forward( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Causal convolution module applied on both utterance and right_context. + + Args: + utterance (torch.Tensor): + Utterance tensor of shape (U, B, D). + right_context (torch.Tensor): + Right context tensor of shape (R, B, D). + cache (torch.Tensor, optional): + Cached tensor for left padding of shape (B, D, cache_size). + + Returns: + A tuple of 3 tensors: + - output utterance of shape (U, B, D). + - output right_context of shape (R, B, D). + - updated cache tensor of shape (B, D, cache_size). + """ + U, B, D = utterance.size() + R, _, _ = right_context.size() + + # point-wise conv and GLU mechanism + x = torch.cat([right_context, utterance], dim=0) # (R + U, B, D) + x = x.permute(1, 2, 0) # (B, D, R + U) + x = self.pointwise_conv1(x) # (B, 2 * D, R + U) + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (B, D, R + U) + utterance = x[:, :, R:] # (B, D, U) + right_context = x[:, :, :R] # (B, D, R) + + if cache is None: + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + else: + assert cache.shape == (B, D, self.cache_size), cache.shape + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) + # update cache + new_cache = pad_utterance[:, :, -self.cache_size :] + + # depth-wise conv on utterance + utterance = self.depthwise_conv(pad_utterance) # (B, D, U) + + if self.right_context_length > 0: + # depth-wise conv on right_context + pad_right_context = self._split_right_context( + pad_utterance.permute(2, 0, 1), right_context.permute(2, 0, 1) + ) # (num_segs * B, D, cache_size + right_context_length) + right_context = self.depthwise_conv( + pad_right_context + ) # (num_segs * B, D, right_context_length) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) + + x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) + x = self.deriv_balancer2(x) + x = self.activation(x) + + # point-wise conv + x = self.pointwise_conv2(x) # (B, D, R + U) + + right_context = x[:, :, :R] # (B, D, R) + utterance = x[:, :, R:] # (B, D, U) + return ( + utterance.permute(2, 0, 1), + right_context.permute(2, 0, 1), + new_cache, + ) + + def infer( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Causal convolution module applied on both utterance and right_context. + + Args: + utterance (torch.Tensor): + Utterance tensor of shape (U, B, D). + right_context (torch.Tensor): + Right context tensor of shape (R, B, D). + cache (torch.Tensor, optional): + Cached tensor for left padding of shape (B, D, cache_size). + + Returns: + A tuple of 3 tensors: + - output utterance of shape (U, B, D). + - output right_context of shape (R, B, D). + - updated cache tensor of shape (B, D, cache_size). + """ + U, B, D = utterance.size() + R, _, _ = right_context.size() + + # point-wise conv + x = torch.cat([utterance, right_context], dim=0) # (U + R, B, D) + x = x.permute(1, 2, 0) # (B, D, U + R) + x = self.pointwise_conv1(x) # (B, 2 * D, U + R) + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (B, D, U + R) + + if cache is None: + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + else: + assert cache.shape == (B, D, self.cache_size), cache.shape + x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R) + # update cache + new_cache = x[:, :, -R - self.cache_size:-R] + + # 1-D depth-wise conv + x = self.depthwise_conv(x) # (B, D, U + R) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + # point-wise conv + x = self.pointwise_conv2(x) # (B, D, U + R) + + utterance = x[:, :, :U] # (B, D, U) + right_context = x[:, :, U:] # (B, D, R) + return ( + utterance.permute(2, 0, 1), + right_context.permute(2, 0, 1), + new_cache, + ) + + class EmformerAttention(nn.Module): r"""Emformer layer attention module. + Relative positional encoding is applied in this module, which is difference + from https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py # noqa + Args: embed_dim (int): Embedding dimension. diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py index 03835f0d7..4549dad22 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py @@ -102,7 +102,61 @@ def test_emformer_attention_infer(): assert next_val.shape == (L + U, B, D) +def test_convolution_module_forward(): + from emformer import ConvolutionModule + + B, D = 2, 256 + chunk_length = 4 + right_context_length = 2 + num_chunks = 3 + U = num_chunks * chunk_length + R = num_chunks * right_context_length + kernel_size = 31 + conv_module = ConvolutionModule( + chunk_length, right_context_length, D, kernel_size, + ) + + utterance = torch.randn(U, B, D) + right_context = torch.randn(R, B, D) + cache = torch.randn(B, D, kernel_size - 1) + + utterance, right_context, new_cache = conv_module( + utterance, right_context, cache + ) + assert utterance.shape == (U, B, D) + assert right_context.shape == (R, B, D) + assert new_cache.shape == (B, D, kernel_size - 1) + + +def test_convolution_module_infer(): + from emformer import ConvolutionModule + + B, D = 2, 256 + chunk_length = 4 + right_context_length = 2 + num_chunks = 1 + U = num_chunks * chunk_length + R = num_chunks * right_context_length + kernel_size = 31 + conv_module = ConvolutionModule( + chunk_length, right_context_length, D, kernel_size, + ) + + utterance = torch.randn(U, B, D) + right_context = torch.randn(R, B, D) + cache = torch.randn(B, D, kernel_size - 1) + + utterance, right_context, new_cache = conv_module.infer( + utterance, right_context, cache + ) + assert utterance.shape == (U, B, D) + assert right_context.shape == (R, B, D) + assert new_cache.shape == (B, D, kernel_size - 1) + + if __name__ == "__main__": test_rel_positional_encoding() test_emformer_attention_forward() test_emformer_attention_infer() + test_convolution_module_forward() + test_convolution_module_infer()