From 13899dff5125ae8d7fe5c6f2dc5bf16c73e6e04c Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 6 Jun 2022 21:19:25 +0800 Subject: [PATCH] refactor, use fixed-length cache for batch decoding --- .../emformer.py | 367 +++++++++--------- 1 file changed, 181 insertions(+), 186 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 65c1b8ced..2e33e0053 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -200,7 +200,6 @@ class ConvolutionModule(nn.Module): self, utterance: torch.Tensor, right_context: torch.Tensor, - cache: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Causal convolution module. @@ -209,14 +208,11 @@ class ConvolutionModule(nn.Module): Utterance tensor of shape (U, B, D). right_context (torch.Tensor): Right context tensor of shape (R, B, D). - cache (torch.Tensor, optional): - Cached tensor for left padding of shape (B, D, cache_size). Returns: - A tuple of 3 tensors: - - output utterance of shape (U, B, D). - - output right_context of shape (R, B, D). - - updated cache tensor of shape (B, D, cache_size). + A tuple of 2 tensors: + - output utterance of shape (U, B, D). + - output right_context of shape (R, B, D). """ U, B, D = utterance.size() R, _, _ = right_context.size() @@ -230,17 +226,13 @@ class ConvolutionModule(nn.Module): utterance = x[:, :, R:] # (B, D, U) right_context = x[:, :, :R] # (B, D, R) - if cache is None: - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - else: - assert cache.shape == (B, D, self.cache_size), cache.shape + # make causal convolution + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) pad_utterance = torch.cat( [cache, utterance], dim=2 ) # (B, D, cache + U) - # update cache - new_cache = pad_utterance[:, :, -self.cache_size :] # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -269,7 +261,6 @@ class ConvolutionModule(nn.Module): return ( utterance.permute(2, 0, 1), right_context.permute(2, 0, 1), - new_cache, ) def infer( @@ -304,12 +295,8 @@ class ConvolutionModule(nn.Module): x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (B, D, U + R) - if cache is None: - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - else: - assert cache.shape == (B, D, self.cache_size), cache.shape + # make causal convolution + assert cache.shape == (B, D, self.cache_size), cache.shape x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R) # update cache new_cache = x[:, :, -R - self.cache_size : -R] @@ -383,7 +370,7 @@ class EmformerAttention(nn.Module): self, attention_weights: torch.Tensor, attention_mask: torch.Tensor, - padding_mask: Optional[torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Given the entire attention weights, mask out unecessary connections and optionally with padding positions, to obtain underlying chunk-wise @@ -438,11 +425,11 @@ class EmformerAttention(nn.Module): def _forward_impl( self, utterance: torch.Tensor, - lengths: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, left_context_key: Optional[torch.Tensor] = None, left_context_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, 
torch.Tensor, torch.Tensor]: @@ -470,7 +457,7 @@ class EmformerAttention(nn.Module): [value[: M + R], left_context_val, value[M + R :]] ) Q = query.size(0) - KV = key.size(0) + # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ tensor.contiguous() @@ -482,12 +469,6 @@ class EmformerAttention(nn.Module): reshaped_query * scaling, reshaped_key.transpose(1, 2) ) # (B * nhead, Q, KV) - # compute padding mask - if B == 1: - padding_mask = None - else: - padding_mask = make_pad_mask(KV - U + lengths) - # compute attention probabilities attention_probs = self._gen_attention_probs( attention_weights, attention_mask, padding_mask @@ -515,11 +496,11 @@ class EmformerAttention(nn.Module): def forward( self, utterance: torch.Tensor, - lengths: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: Modify docs. """Forward pass for training and validation mode. @@ -560,9 +541,6 @@ class EmformerAttention(nn.Module): Args: utterance (torch.Tensor): Full utterance frames, with shape (U, B, D). - lengths (torch.Tensor): - With shape (B,) and i-th element representing - number of valid frames for i-th batch element in utterance. right_context (torch.Tensor): Hard-copied right context frames, with shape (R, B, D), where R = num_chunks * right_context_length @@ -575,6 +553,8 @@ class EmformerAttention(nn.Module): attention_mask (torch.Tensor): Pre-computed attention mask to simulate underlying chunk-wise attention, with shape (Q, KV). + padding_mask (torch.Tensor): + Padding mask of key tensor, with shape (B, KV). Returns: A tuple containing 2 tensors: @@ -588,23 +568,23 @@ class EmformerAttention(nn.Module): _, ) = self._forward_impl( utterance, - lengths, right_context, summary, memory, attention_mask, + padding_mask=padding_mask, ) return output_right_context_utterance, output_memory[:-1] def infer( self, utterance: torch.Tensor, - lengths: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor, memory: torch.Tensor, left_context_key: torch.Tensor, left_context_val: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass for inference. @@ -633,9 +613,6 @@ class EmformerAttention(nn.Module): Args: utterance (torch.Tensor): Current chunk frames, with shape (U, B, D), where U = chunk_length. - lengths (torch.Tensor): - With shape (B,) and i-th element representing - number of valid frames for i-th batch element in utterance. right_context (torch.Tensor): Right context frames, with shape (R, B, D), where R = right_context_length. @@ -645,10 +622,12 @@ class EmformerAttention(nn.Module): Memory vectors, with shape (M, B, D), or empty tensor. left_context_key (torch,Tensor): Cached attention key of left context from preceding computation, - with shape (L, B, D), where L <= left_context_length. + with shape (L, B, D). left_context_val (torch.Tensor): Cached attention value of left context from preceding computation, - with shape (L, B, D), where L <= left_context_length. + with shape (L, B, D). + padding_mask (torch.Tensor): + Padding mask of key tensor, with shape (B, KV). 
Returns: A tuple containing 4 tensors: @@ -665,6 +644,7 @@ class EmformerAttention(nn.Module): S = summary.size(0) M = memory.size(0) + # TODO: move it outside # query = [right context, utterance, summary] Q = R + U + S # key, value = [memory, right context, left context, uttrance] @@ -681,11 +661,11 @@ class EmformerAttention(nn.Module): value, ) = self._forward_impl( utterance, - lengths, right_context, summary, memory, attention_mask, + padding_mask=padding_mask, left_context_key=left_context_key, left_context_val=left_context_val, ) @@ -719,8 +699,8 @@ class EmformerEncoderLayer(nn.Module): Length of left context. (Default: 0) right_context_length (int, optional): Length of right context. (Default: 0) - max_memory_size (int, optional): - Maximum number of memory elements to use. (Default: 0) + memory_size (int, optional): + Number of memory elements to use. (Default: 0) tanh_on_mem (bool, optional): If ``True``, applies tanh to memory elements. (Default: ``False``) negative_inf (float, optional): @@ -738,7 +718,7 @@ class EmformerEncoderLayer(nn.Module): cnn_module_kernel: int = 31, left_context_length: int = 0, right_context_length: int = 0, - max_memory_size: int = 0, + memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, ): @@ -791,75 +771,29 @@ class EmformerEncoderLayer(nn.Module): self.layer_dropout = layer_dropout self.left_context_length = left_context_length self.chunk_length = chunk_length - self.max_memory_size = max_memory_size + self.memory_size = memory_size self.d_model = d_model - self.use_memory = max_memory_size > 0 + self.use_memory = memory_size > 0 - def _init_state( - self, batch_size: int, device: Optional[torch.device] - ) -> List[torch.Tensor]: - """Initialize states with zeros.""" - empty_memory = torch.zeros( - self.max_memory_size, batch_size, self.d_model, device=device - ) - left_context_key = torch.zeros( - self.left_context_length, batch_size, self.d_model, device=device - ) - left_context_val = torch.zeros( - self.left_context_length, batch_size, self.d_model, device=device - ) - past_length = torch.zeros( - 1, batch_size, dtype=torch.int32, device=device - ) - return [empty_memory, left_context_key, left_context_val, past_length] - - def _unpack_state( - self, state: List[torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Unpack cached states including: - 1) output memory from previous chunks in the lower layer; - 2) attention key and value of left context from preceding chunk's - computation. 
- """ - past_length = state[3][0][0].item() - past_left_context_length = min(self.left_context_length, past_length) - past_memory_length = min( - self.max_memory_size, math.ceil(past_length / self.chunk_length) - ) - memory_start_idx = self.max_memory_size - past_memory_length - pre_memory = state[0][memory_start_idx:] - left_context_start_idx = ( - self.left_context_length - past_left_context_length - ) - left_context_key = state[1][left_context_start_idx:] - left_context_val = state[2][left_context_start_idx:] - return pre_memory, left_context_key, left_context_val - - def _pack_state( + def _update_attn_cache( self, next_key: torch.Tensor, next_val: torch.Tensor, - update_length: int, memory: torch.Tensor, - state: List[torch.Tensor], + attn_cache: List[torch.Tensor], ) -> List[torch.Tensor]: - """Pack updated states including: + """Update cached attention state: 1) output memory of current chunk in the lower layer; 2) attention key and value in current chunk's computation, which would be resued in next chunk's computation. - 3) length of current chunk. """ - new_memory = torch.cat([state[0], memory]) - new_key = torch.cat([state[1], next_key]) - new_val = torch.cat([state[2], next_val]) - memory_start_idx = new_memory.size(0) - self.max_memory_size - state[0] = new_memory[memory_start_idx:] - key_start_idx = new_key.size(0) - self.left_context_length - state[1] = new_key[key_start_idx:] - val_start_idx = new_val.size(0) - self.left_context_length - state[2] = new_val[val_start_idx:] - state[3] = state[3] + update_length - return state + new_memory = torch.cat([attn_cache[0], memory]) + new_key = torch.cat([attn_cache[1], next_key]) + new_val = torch.cat([attn_cache[2], next_val]) + attn_cache[0] = new_memory[new_memory.size(0) - self.memory_size :] + attn_cache[1] = new_key[new_key.size(0) - self.left_context_length :] + attn_cache[2] = new_val[new_val.size(0) - self.left_context_length :] + return attn_cache def _apply_conv_module_forward( self, @@ -869,7 +803,7 @@ class EmformerEncoderLayer(nn.Module): """Apply convolution module in training and validation mode.""" utterance = right_context_utterance[R:] right_context = right_context_utterance[:R] - utterance, right_context, _ = self.conv_module(utterance, right_context) + utterance, right_context = self.conv_module(utterance, right_context) right_context_utterance = torch.cat([right_context, utterance]) return right_context_utterance @@ -892,15 +826,11 @@ class EmformerEncoderLayer(nn.Module): self, right_context_utterance: torch.Tensor, R: int, - lengths: torch.Tensor, memory: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply attention module in training and validation mode.""" - if attention_mask is None: - raise ValueError( - "attention_mask must be not None in training or validation mode." 
# noqa - ) utterance = right_context_utterance[R:] right_context = right_context_utterance[:R] @@ -914,11 +844,11 @@ class EmformerEncoderLayer(nn.Module): ) output_right_context_utterance, output_memory = self.attention( utterance=utterance, - lengths=lengths, right_context=right_context, summary=summary, memory=memory, attention_mask=attention_mask, + padding_mask=padding_mask, ) return output_right_context_utterance, output_memory @@ -927,9 +857,9 @@ class EmformerEncoderLayer(nn.Module): self, right_context_utterance: torch.Tensor, R: int, - lengths: torch.Tensor, memory: torch.Tensor, - state: Optional[List[torch.Tensor]] = None, + attn_cache: List[torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """Apply attention module in inference mode. 1) Unpack cached states including: @@ -937,7 +867,7 @@ class EmformerEncoderLayer(nn.Module): - attention key and value of left context from preceding chunk's compuation; 2) Apply attention computation; - 3) Pack updated states including: + 3) Update cached attention states including: - output memory of current chunk in the lower layer; - attention key and value in current chunk's computation, which would be resued in next chunk's computation. @@ -946,11 +876,10 @@ class EmformerEncoderLayer(nn.Module): utterance = right_context_utterance[R:] right_context = right_context_utterance[:R] - if state is None: - state = self._init_state(utterance.size(1), device=utterance.device) - pre_memory, left_context_key, left_context_val = self._unpack_state( - state - ) + pre_memory = attn_cache[0] + left_context_key = attn_cache[1] + left_context_val = attn_cache[2] + if self.use_memory: summary = self.summary_op(utterance.permute(1, 2, 0)).permute( 2, 0, 1 @@ -967,25 +896,25 @@ class EmformerEncoderLayer(nn.Module): next_val, ) = self.attention.infer( utterance=utterance, - lengths=lengths, right_context=right_context, summary=summary, memory=pre_memory, left_context_key=left_context_key, left_context_val=left_context_val, + padding_mask=padding_mask, ) - state = self._pack_state( - next_key, next_val, utterance.size(0), memory, state + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache ) - return output_right_context_utterance, output_memory, state + return output_right_context_utterance, output_memory, attn_cache def forward( self, utterance: torch.Tensor, - lengths: torch.Tensor, right_context: torch.Tensor, memory: torch.Tensor, - attention_mask: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, warmup: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: r"""Forward pass for training and validation mode. @@ -999,9 +928,6 @@ class EmformerEncoderLayer(nn.Module): Args: utterance (torch.Tensor): Utterance frames, with shape (U, B, D). - lengths (torch.Tensor): - With shape (B,) and i-th element representing - number of valid frames for i-th batch element in utterance. right_context (torch.Tensor): Right context frames, with shape (R, B, D). memory (torch.Tensor): @@ -1010,6 +936,8 @@ class EmformerEncoderLayer(nn.Module): attention_mask (torch.Tensor): Attention mask for underlying attention module, with shape (Q, KV), where Q = R + U + S, KV = M + R + U. + padding_mask (torch.Tensor): + Padding mask of ker tensor, with shape (B, KV). 
Returns: A tuple containing 3 tensors: @@ -1038,7 +966,7 @@ class EmformerEncoderLayer(nn.Module): # emformer attention module src_att, output_memory = self._apply_attention_module_forward( - src, R, lengths, memory, attention_mask + src, R, memory, attention_mask, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1061,11 +989,11 @@ class EmformerEncoderLayer(nn.Module): def infer( self, utterance: torch.Tensor, - lengths: torch.Tensor, right_context: torch.Tensor, memory: torch.Tensor, - state: Optional[List[torch.Tensor]] = None, - conv_cache: Optional[torch.Tensor] = None, + attn_cache: List[torch.Tensor], + conv_cache: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: """Forward pass for inference. @@ -1078,18 +1006,17 @@ class EmformerEncoderLayer(nn.Module): Args: utterance (torch.Tensor): Utterance frames, with shape (U, B, D). - lengths (torch.Tensor): - With shape (B,) and i-th element representing - number of valid frames for i-th batch element in utterance. right_context (torch.Tensor): Right context frames, with shape (R, B, D). memory (torch.Tensor): Memory elements, with shape (M, B, D). - state (List[torch.Tensor], optional): - List of tensors representing layer internal state generated in - preceding computation. (default=None) + attn_cache (List[torch.Tensor]): + Cached attention tensors generated in preceding computation, + including memory, key and value of left context. conv_cache (torch.Tensor, optional): Cache tensor of left context for causal convolution. + padding_mask (torch.Tensor): + Padding mask of ker tensor. Returns: (Tensor, Tensor, List[torch.Tensor], Tensor): @@ -1109,8 +1036,10 @@ class EmformerEncoderLayer(nn.Module): ( src_att, output_memory, - output_state, - ) = self._apply_attention_module_infer(src, R, lengths, memory, state) + attn_cache, + ) = self._apply_attention_module_infer( + src, R, memory, attn_cache, padding_mask=padding_mask + ) src = src + self.dropout(src_att) # convolution module @@ -1128,7 +1057,7 @@ class EmformerEncoderLayer(nn.Module): output_utterance, output_right_context, output_memory, - output_state, + attn_cache, conv_cache, ) @@ -1179,8 +1108,8 @@ class EmformerEncoder(nn.Module): Length of left context. (default: 0) right_context_length (int, optional): Length of right context. (default: 0) - max_memory_size (int, optional): - Maximum number of memory elements to use. (default: 0) + memory_size (int, optional): + Number of memory elements to use. (default: 0) tanh_on_mem (bool, optional): If ``true``, applies tanh to memory elements. 
(default: ``false``) negative_inf (float, optional): @@ -1199,13 +1128,13 @@ class EmformerEncoder(nn.Module): cnn_module_kernel: int = 31, left_context_length: int = 0, right_context_length: int = 0, - max_memory_size: int = 0, + memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, ): super().__init__() - self.use_memory = max_memory_size > 0 + self.use_memory = memory_size > 0 self.init_memory_op = nn.AvgPool1d( kernel_size=chunk_length, stride=chunk_length, @@ -1224,7 +1153,7 @@ class EmformerEncoder(nn.Module): cnn_module_kernel=cnn_module_kernel, left_context_length=left_context_length, right_context_length=right_context_length, - max_memory_size=max_memory_size, + memory_size=memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, ) @@ -1232,10 +1161,13 @@ class EmformerEncoder(nn.Module): ] ) + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model self.left_context_length = left_context_length self.right_context_length = right_context_length self.chunk_length = chunk_length - self.max_memory_size = max_memory_size + self.memory_size = memory_size + self.cnn_module_kernel = cnn_module_kernel def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" @@ -1276,7 +1208,7 @@ class EmformerEncoder(nn.Module): R = rc * num_chunks if self.use_memory: - m_start = max(chunk_idx - self.max_memory_size, 0) + m_start = max(chunk_idx - self.memory_size, 0) M = num_chunks - 1 col_widths = [ m_start, # before memory @@ -1430,15 +1362,18 @@ class EmformerEncoder(nn.Module): if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) + padding_mask = make_pad_mask( + memory.size(0) + right_context.size(0) + output_lengths + ) output = utterance for layer in self.emformer_layers: output, right_context, memory = layer( output, - output_lengths, right_context, memory, attention_mask, + padding_mask=padding_mask, warmup=warmup, ) @@ -1448,10 +1383,13 @@ class EmformerEncoder(nn.Module): self, x: torch.Tensor, lengths: torch.Tensor, - states: Optional[List[List[torch.Tensor]]] = None, - conv_caches: Optional[List[torch.Tensor]] = None, + states: List[ + torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor] + ], ) -> Tuple[ - torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor] + torch.Tensor, + torch.Tensor, + List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]], ]: """Forward pass for streaming inference. @@ -1467,13 +1405,13 @@ class EmformerEncoder(nn.Module): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, which contains the right_context at the end. - states (List[List[torch.Tensor]], optional): - Cached states from preceding chunk's computation, where each - element (List[torch.Tensor]) corresponds to each emformer layer. - (default: None) - conv_caches (List[torch.Tensor], optional): - Cached tensors of left context for causal convolution, where each - element (Tensor) corresponds to each convolutional layer. + states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa + Cached states containing: + - past_lens: number of past frames for each sample in batch + - attn_caches: attention states from preceding chunk's computation, + where each element corresponds to each emformer layer + - conv_caches: left context for causal convolution, where each + element corresponds to each layer. 
Returns: (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]): @@ -1481,8 +1419,38 @@ class EmformerEncoder(nn.Module): - output lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation. - - updated convolution caches from current chunk. """ + past_lens = states[0] + assert past_lens.shape == (x.size(1),), past_lens.shape + + attn_caches = states[1] + assert len(attn_caches) == self.num_encoder_layers, len(attn_caches) + for i in range(len(attn_caches)): + assert attn_caches[i][0].shape == ( + self.memory_size, + x.size(1), + self.d_model, + ), attn_caches[i][0].shape + assert attn_caches[i][1].shape == ( + self.left_context_length, + x.size(1), + self.d_model, + ), attn_caches[i][1].shape + assert attn_caches[i][2].shape == ( + self.left_context_length, + x.size(1), + self.d_model, + ), attn_caches[i][2].shape + + conv_caches = states[2] + assert len(conv_caches) == self.num_encoder_layers, len(conv_caches) + for i in range(len(conv_caches)): + assert conv_caches[i].shape == ( + x.size(1), + self.d_model, + self.cnn_module_kernel, + ), conv_caches[i].shape + assert x.size(0) == self.chunk_length + self.right_context_length, ( "Per configured chunk_length and right_context_length, " f"expected size of {self.chunk_length + self.right_context_length} " @@ -1498,28 +1466,60 @@ class EmformerEncoder(nn.Module): if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) + + # calcualte padding mask + chunk_mask = make_pad_mask(output_lengths) + memory_mask = ( + (past_lens // self.chunk_length).view(x.size(1), 1) + <= torch.arange(self.memory_size, device=x.device).expand( + x.size(1), self.memory_size + ) + ).flip(1) + left_context_mask = ( + past_lens.view(x.size(1), 1) + <= torch.arange(self.left_context_length, device=x.device).expand( + x.size(1), self.left_context_length + ) + ).flip(1) + right_context_mask = torch.zeros( + x.size(1), + self.right_context_length, + dtype=torch.bool, + device=x.device, + ) + padding_mask = torch.cat( + [memory_mask, left_context_mask, right_context_mask, chunk_mask], + dim=1, + ) + output = utterance - output_states: List[List[torch.Tensor]] = [] + output_attn_caches: List[List[torch.Tensor]] = [] output_conv_caches: List[torch.Tensor] = [] for layer_idx, layer in enumerate(self.emformer_layers): ( output, right_context, memory, - output_state, + output_attn_cache, output_conv_cache, ) = layer.infer( output, - output_lengths, right_context, memory, - None if states is None else states[layer_idx], - None if conv_caches is None else conv_caches[layer_idx], + padding_mask=padding_mask, + attn_cache=attn_caches[layer_idx], + conv_cache=conv_caches[layer_idx], ) - output_states.append(output_state) + output_attn_caches.append(output_attn_cache) output_conv_caches.append(output_conv_cache) - return output, output_lengths, output_states, output_conv_caches + output_past_lens = past_lens + output_lengths + output_states = [ + output_past_lens, + output_attn_caches, + output_conv_caches, + ] + return output, output_lengths, output_states class Emformer(EncoderInterface): @@ -1537,7 +1537,7 @@ class Emformer(EncoderInterface): cnn_module_kernel: int = 3, left_context_length: int = 0, right_context_length: int = 0, - max_memory_size: int = 0, + memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, ): @@ -1576,7 +1576,7 @@ class Emformer(EncoderInterface): cnn_module_kernel=cnn_module_kernel, left_context_length=left_context_length // 4, 
right_context_length=right_context_length // 4, - max_memory_size=max_memory_size, + memory_size=memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, ) @@ -1633,7 +1633,6 @@ class Emformer(EncoderInterface): x: torch.Tensor, x_lens: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None, - conv_caches: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: """Forward pass for streaming inference. @@ -1649,13 +1648,13 @@ class Emformer(EncoderInterface): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, containing the right_context at the end. - states (List[List[torch.Tensor]], optional): - Cached states from preceding chunk's computation, where each - element (List[torch.Tensor]) corresponds to each emformer layer. - (default: None) - conv_caches (List[torch.Tensor], optional): - Cached tensors of left context for causal convolution, where each - element (Tensor) corresponds to each convolutional layer. + states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa + Cached states containing: + - past_lens: number of past frames for each sample in batch + - attn_caches: attention states from preceding chunk's computation, + where each element corresponds to each emformer layer + - conv_caches: left context for causal convolution, where each + element corresponds to each layer. Returns: (Tensor, Tensor): - output embedding, with shape (B, T', D), where @@ -1663,7 +1662,6 @@ class Emformer(EncoderInterface): - output lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation. - - updated convolution caches from current chunk. """ x = self.encoder_embed(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -1674,16 +1672,13 @@ class Emformer(EncoderInterface): x_lens = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == x_lens.max().item() - ( - output, - output_lengths, - output_states, - output_conv_caches, - ) = self.encoder.infer(x, x_lens, states, conv_caches) + output, output_lengths, output_states = self.encoder.infer( + x, x_lens, states + ) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - return output, output_lengths, output_states, output_conv_caches + return output, output_lengths, output_states class Conv2dSubsampling(nn.Module):
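
Note on usage after this refactoring: the streaming caches now have fixed lengths, so a caller has to build the initial `states` list itself instead of passing `None` and relying on the removed `_init_state`. The sketch below is a hypothetical `init_states` helper, not part of this patch; the shapes follow the assertions in `EmformerEncoder.infer` above, and zero-initialized caches together with `past_lens = 0` are treated as fully padded by the new padding-mask logic.

import torch
from typing import List


def init_states(
    batch_size: int,
    num_encoder_layers: int,
    d_model: int,
    memory_size: int,
    left_context_length: int,
    cnn_module_kernel: int,
    device: torch.device = torch.device("cpu"),
) -> List:
    """Hypothetical helper (not part of this patch): build zero-initialized
    fixed-length caches matching the shapes asserted in EmformerEncoder.infer.

    Note: left_context_length here is the subsampled value used inside
    EmformerEncoder (the Emformer wrapper divides the configured value by 4).
    The conv cache length follows the assertion in this patch
    (cnn_module_kernel); if the convolution module's actual cache_size
    differs (e.g. kernel_size - 1), use that value instead.
    """
    # number of already-processed valid frames per sample; 0 at stream start,
    # so the padding mask marks the whole cache as padding for the first chunk
    past_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)

    # per layer: [cached memory, cached left-context key, cached left-context value]
    attn_caches = [
        [
            torch.zeros(memory_size, batch_size, d_model, device=device),
            torch.zeros(left_context_length, batch_size, d_model, device=device),
            torch.zeros(left_context_length, batch_size, d_model, device=device),
        ]
        for _ in range(num_encoder_layers)
    ]

    # per layer: left padding cache for the causal convolution module
    conv_caches = [
        torch.zeros(batch_size, d_model, cnn_module_kernel, device=device)
        for _ in range(num_encoder_layers)
    ]

    return [past_lens, attn_caches, conv_caches]

A streaming loop would then feed `chunk_length + right_context_length` frames per call (as asserted in `EmformerEncoder.infer`), and thread the returned `[past_lens, attn_caches, conv_caches]` back into the next call of `infer`; `Emformer.infer` forwards `states` to the encoder unchanged.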