From c2808f8541371d55e1cbd9d3129d0851c7295d92 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Tue, 12 Apr 2022 20:13:51 +0800
Subject: [PATCH] Support cache of left context for causal convolution.

---
 .../ASR/conv_emformer_transducer/emformer.py  | 153 +++++++++++++-----
 .../conv_emformer_transducer/test_emformer.py |  32 ++--
 .../ASR/conv_emformer_transducer/train.py     |   2 +-
 3 files changed, 134 insertions(+), 53 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
index 5ac65141e..e9ce56aa7 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py
@@ -601,24 +601,8 @@ class EmformerLayer(nn.Module):
         )
         return right_context_utterance
 
-    def _apply_conv_module(
-        self,
-        right_context_utterance: torch.Tensor,
-        right_context_end_idx: int,
-    ) -> torch.Tensor:
-        """Apply convolution module on utterance."""
-        utterance = right_context_utterance[right_context_end_idx:]
-        right_context = right_context_utterance[:right_context_end_idx]
-
-        residual = utterance
-        utterance = self.norm_conv(utterance)
-        utterance = residual + self.dropout(self.conv_module(utterance))
-        right_context_utterance = torch.cat([right_context, utterance])
-        return right_context_utterance
-
     def _apply_feed_forward_module(
-        self,
-        right_context_utterance: torch.Tensor,
+        self, right_context_utterance: torch.Tensor
     ) -> torch.Tensor:
         """Apply feed forward module."""
         residual = right_context_utterance
@@ -628,6 +612,39 @@ class EmformerLayer(nn.Module):
         )
         return right_context_utterance
 
+    def _apply_conv_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        right_context_end_idx: int,
+    ) -> torch.Tensor:
+        """Apply convolution module on utterance in non-infer mode."""
+        utterance = right_context_utterance[right_context_end_idx:]
+        right_context = right_context_utterance[:right_context_end_idx]
+
+        residual = utterance
+        utterance = self.norm_conv(utterance)
+        utterance, _ = self.conv_module(utterance)
+        utterance = residual + self.dropout(utterance)
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance
+
+    def _apply_conv_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        right_context_end_idx: int,
+        conv_cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply convolution module on utterance in infer mode."""
+        utterance = right_context_utterance[right_context_end_idx:]
+        right_context = right_context_utterance[:right_context_end_idx]
+
+        residual = utterance
+        utterance = self.norm_conv(utterance)
+        utterance, conv_cache = self.conv_module(utterance, conv_cache)
+        utterance = residual + self.dropout(utterance)
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance, conv_cache
+
     def _apply_attention_module_forward(
         self,
         right_context_utterance: torch.Tensor,
@@ -790,7 +807,7 @@ class EmformerLayer(nn.Module):
             attention_mask,
         )
 
-        right_context_utterance = self._apply_conv_module(
+        right_context_utterance = self._apply_conv_module_forward(
             right_context_utterance, right_context_end_idx
         )
 
@@ -812,6 +829,7 @@ class EmformerLayer(nn.Module):
         right_context: torch.Tensor,
         memory: torch.Tensor,
         state: Optional[List[torch.Tensor]] = None,
+        conv_cache: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
         """Forward pass for inference.
 
@@ -841,6 +859,8 @@ class EmformerLayer(nn.Module):
           state (List[torch.Tensor], optional):
             List of tensors representing layer internal state
             generated in preceding computation. (default=None)
+          conv_cache (torch.Tensor, optional):
+            Cache tensor of left context for causal convolution.
 
         Returns:
@@ -848,6 +868,7 @@ class EmformerLayer(nn.Module):
-          (Tensor, Tensor, List[torch.Tensor], Tensor):
+          (Tensor, Tensor, Tensor, List[torch.Tensor], Tensor):
            - output utterance, with shape (U, B, D);
            - output right_context, with shape (R, B, D);
            - output memory, with shape (1, B, D) or (0, B, D).
            - output state.
+           - updated conv_cache.
         """
         right_context_utterance = torch.cat([right_context, utterance])
         right_context_end_idx = right_context.size(0)
@@ -868,8 +889,10 @@ class EmformerLayer(nn.Module):
             state,
         )
 
-        right_context_utterance = self._apply_conv_module(
-            right_context_utterance, right_context_end_idx
+        right_context_utterance, conv_cache = self._apply_conv_module_infer(
+            right_context_utterance,
+            right_context_end_idx,
+            conv_cache,
         )
 
         right_context_utterance = self._apply_feed_forward_module(
@@ -885,6 +908,7 @@ class EmformerLayer(nn.Module):
             output_right_context,
             output_memory,
             output_state,
+            conv_cache,
         )
 
 
@@ -1156,7 +1180,10 @@ class EmformerEncoder(nn.Module):
         x: torch.Tensor,
         lengths: torch.Tensor,
         states: Optional[List[List[torch.Tensor]]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        conv_caches: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]
+    ]:
         """Forward pass for streaming inference.
 
         B: batch size;
@@ -1173,15 +1200,18 @@ class EmformerEncoder(nn.Module):
             right_context at the end.
           states (List[List[torch.Tensor]], optional):
-            Cached states from proceeding chunk's computation, where each
-            element (List[torch.Tensor]) corresponding to each emformer layer.
+            Cached states from preceding chunk's computation, where each
+            element (List[torch.Tensor]) corresponds to each emformer layer.
             (default: None)
-
+          conv_caches (List[torch.Tensor], optional):
+            Cached tensors of left context for causal convolution, where each
+            element (Tensor) corresponds to each convolutional layer.
         Returns:
-          (Tensor, Tensor, List[List[torch.Tensor]]):
+          (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
            - output utterance frames, with shape (U, B, D).
            - output lengths, with shape (B,), without containing the
              right_context at the end.
            - updated states from current chunk's computation.
+           - updated convolution caches from current chunk.
""" assert x.size(0) == self.chunk_length + self.right_context_length, ( "Per configured chunk_length and right_context_length, " @@ -1199,17 +1229,26 @@ class EmformerEncoder(nn.Module): ) output = utterance output_states: List[List[torch.Tensor]] = [] + output_conv_caches: List[torch.Tensor] = [] for layer_idx, layer in enumerate(self.emformer_layers): - output, right_context, memory, output_state = layer.infer( + ( + output, + right_context, + memory, + output_state, + output_conv_cache, + ) = layer.infer( output, output_lengths, right_context, memory, None if states is None else states[layer_idx], + None if conv_caches is None else conv_caches[layer_idx], ) output_states.append(output_state) + output_conv_caches.append(output_conv_cache) - return output, output_lengths, output_states + return output, output_lengths, output_states, output_conv_caches class Emformer(EncoderInterface): @@ -1328,6 +1367,7 @@ class Emformer(EncoderInterface): x: torch.Tensor, x_lens: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None, + conv_caches: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: """Forward pass for streaming inference. @@ -1345,8 +1385,11 @@ class Emformer(EncoderInterface): right_context at the end. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each - element (List[torch.Tensor]) corresponding to each emformer layer. + element (List[torch.Tensor]) corresponds to each emformer layer. (default: None) + conv_caches (List[torch.Tensor], optional): + Cached tensors of left context for causal convolution, where each + element (Tensor) corresponds to each convolutional layer. Returns: (Tensor, Tensor): - output logits, with shape (B, T', D), where @@ -1354,6 +1397,7 @@ class Emformer(EncoderInterface): - logits lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation. + - updated convolution caches from current chunk. """ x = self.encoder_embed(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -1364,14 +1408,17 @@ class Emformer(EncoderInterface): x_lens = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == x_lens.max().item() - output, output_lengths, output_states = self.encoder.infer( - x, x_lens, states - ) # (T, N, C) + ( + output, + output_lengths, + output_states, + output_conv_caches, + ) = self.encoder.infer(x, x_lens, states, conv_caches) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, output_lengths, output_states + return logits, output_lengths, output_states, output_conv_caches class ConvolutionModule(nn.Module): @@ -1437,28 +1484,50 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Compute convolution module. Args: - x: Input tensor (#time, batch, channels). - + x (torch.Tensor): + Input tensor (#time, batch, channels). + cache (torch.Tensor, optional): + Cached tensor for left padding (#batch, channels, cache_time). Returns: - Tensor: Output tensor (#time, batch, channels). - + A tuple of 2 tensors: + - output tensor (#time, batch, channels). + - updated cache tensor (#batch, channels, cache_time). """ # exchange the temporal dimension and the feature dimension x = x.permute(1, 2, 0) # (#batch, channels, time). 
-        # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
-        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
-
         # 1D Depthwise Conv
         if self.left_padding > 0:
-            # manualy padding self.lorder zeros to the left
+            # manually pad self.left_padding zeros to the left to
             # make depthwise_conv causal
-            x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0)
+            if cache is None:
+                x = nn.functional.pad(
+                    x, (self.left_padding, 0), "constant", 0.0
+                )
+            else:
+                assert cache.size(0) == x.size(0)  # equal batch
+                assert cache.size(1) == x.size(1)  # equal channel
+                assert cache.size(2) == self.left_padding
+                x = torch.cat([cache, x], dim=2)
+            new_cache = x[:, :, x.size(2) - self.left_padding :]  # noqa
+        else:
+            # No cache is needed when there is no left padding. Note that
+            # for JIT export, this None would have to be replaced by a
+            # placeholder tensor.
+            new_cache = None
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
         x = x.permute(0, 2, 1)
@@ -1469,7 +1538,7 @@ class ConvolutionModule(nn.Module):
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
-        return x.permute(2, 0, 1)
+        return x.permute(2, 0, 1), new_cache
 
 
 class Swish(torch.nn.Module):

diff --git a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
index 7685bfb26..41e911e17 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py
@@ -133,6 +133,7 @@ def test_emformer_layer_infer():
     R, L = 2, 5
     chunk_length = 2
     U = chunk_length
+    K = 3
 
     for use_memory in [True, False]:
         if use_memory:
@@ -145,7 +146,7 @@ def test_emformer_layer_infer():
             nhead=8,
             dim_feedforward=1024,
             chunk_length=chunk_length,
-            cnn_module_kernel=3,
+            cnn_module_kernel=K,
             left_context_length=L,
             max_memory_size=M,
             causal=True,
@@ -157,17 +158,15 @@ def test_emformer_layer_infer():
         right_context = torch.randn(R, B, D)
         memory = torch.randn(M, B, D)
         state = None
+        conv_cache = None
         (
             output_utterance,
             output_right_context,
             output_memory,
             output_state,
+            output_conv_cache,
         ) = layer.infer(
-            utterance,
-            lengths,
-            right_context,
-            memory,
-            state,
+            utterance, lengths, right_context, memory, state, conv_cache
         )
         assert output_utterance.shape == (U, B, D)
         assert output_right_context.shape == (R, B, D)
@@ -180,6 +179,7 @@ def test_emformer_layer_infer():
         assert output_state[1].shape == (L, B, D)
         assert output_state[2].shape == (L, B, D)
         assert output_state[3].shape == (1, B)
+        assert output_conv_cache.shape == (B, D, K - 1)
 
 
 def test_emformer_encoder_forward():
@@ -226,6 +226,7 @@ def test_emformer_encoder_infer():
     U = chunk_length
     num_chunks = 3
     num_encoder_layers = 2
+    K = 3
 
     for use_memory in [True, False]:
         if use_memory:
@@ -238,7 +239,7 @@ def test_emformer_encoder_infer():
             d_model=D,
             dim_feedforward=1024,
             num_encoder_layers=num_encoder_layers,
-            cnn_module_kernel=3,
+            cnn_module_kernel=K,
             left_context_length=L,
             right_context_length=R,
             max_memory_size=M,
@@ -246,11 +247,14 @@ def test_emformer_encoder_infer():
         )
 
         states = None
+        conv_caches = None
         for chunk_idx in range(num_chunks):
             x = torch.randn(U + R, B, D)
             lengths = torch.randint(1, U + R + 1, (B,))
             lengths[0] = U + R
-            output, output_lengths, states = encoder.infer(x, lengths, states)
+            output, output_lengths, states, conv_caches = encoder.infer(
+                x, lengths, states, conv_caches
+            )
             assert output.shape == (U, B, D)
             assert torch.equal(output_lengths,
                                 torch.clamp(lengths - R, min=0))
             assert len(states) == num_encoder_layers
@@ -262,6 +266,8 @@ def test_emformer_encoder_infer():
             assert torch.equal(
                 state[3], (chunk_idx + 1) * U * torch.ones_like(state[3])
             )
+        for conv_cache in conv_caches:
+            assert conv_cache.shape == (B, D, K - 1)
 
 
 def test_emformer_forward():
@@ -312,6 +318,7 @@ def test_emformer_infer():
     B, D = 2, 256
     num_chunks = 3
     num_encoder_layers = 2
+    K = 3
     for use_memory in [True, False]:
         if use_memory:
             M = 3
@@ -324,7 +331,7 @@ def test_emformer_infer():
         subsampling_factor=4,
         d_model=D,
         num_encoder_layers=num_encoder_layers,
-        cnn_module_kernel=3,
+        cnn_module_kernel=K,
         left_context_length=L,
         right_context_length=R,
         max_memory_size=M,
@@ -332,11 +339,14 @@ def test_emformer_infer():
         causal=True,
     )
     states = None
+    conv_caches = None
    for chunk_idx in range(num_chunks):
        x = torch.randn(B, U + R + 3, num_features)
        x_lens = torch.randint(1, U + R + 3 + 1, (B,))
        x_lens[0] = U + R + 3
-        logits, output_lengths, states = model.infer(x, x_lens, states)
+        logits, output_lengths, states, conv_caches = model.infer(
+            x, x_lens, states, conv_caches
+        )
        assert logits.shape == (B, U // 4, output_dim)
        assert torch.equal(
            output_lengths,
@@ -352,6 +362,8 @@ def test_emformer_infer():
                state[3],
                U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]),
            )
+        for conv_cache in conv_caches:
+            assert conv_cache.shape == (B, D, K - 1)
 
 
 if __name__ == "__main__":

diff --git a/egs/librispeech/ASR/conv_emformer_transducer/train.py b/egs/librispeech/ASR/conv_emformer_transducer/train.py
index d0126bb94..8a0eecc6b 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer/train.py
@@ -139,7 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
 
     parser.add_argument(
         "--causal-conv",
-        type=bool,
+        type=str2bool,
         default=True,
-        help="Whether use causal convolution.",
+        help="Whether to use causal convolution.",
     )
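Usage sketch (not part of the patch): the new conv_caches argument is threaded
through Emformer.infer() exactly like states — both start as None on the first
chunk and are fed back in on every subsequent call, mirroring
test_emformer_infer() above. A minimal streaming loop; the built `model` and
the `feature_chunks` iterable are assumed for illustration only:

    import torch

    states = None       # per-layer attention/memory states
    conv_caches = None  # per-layer left-context caches of the causal convolution

    for chunk in feature_chunks:
        # chunk: (B, chunk_length + right_context_length + 3, num_features);
        # the extra 3 frames are consumed by the subsampling frontend,
        # as in the tests above.
        x_lens = torch.full((chunk.size(0),), chunk.size(1), dtype=torch.int64)
        logits, logit_lengths, states, conv_caches = model.infer(
            chunk, x_lens, states, conv_caches
        )
        # Each element of conv_caches now has shape
        # (B, d_model, cnn_module_kernel - 1); passing it back in gives the
        # depthwise convolution its left context without recomputing it.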