From 193b44ed7aa7649546c49806a4fc4d59b3b0ddfe Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 13 Jun 2022 22:14:24 +0800 Subject: [PATCH] use average value as memory vector for each chunk --- .../emformer.py | 208 +++++------------- 1 file changed, 53 insertions(+), 155 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 46993da48..c5d862ad8 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -537,7 +537,6 @@ class EmformerAttention(nn.Module): self, utterance: torch.Tensor, right_context: torch.Tensor, - summary: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, @@ -550,10 +549,8 @@ class EmformerAttention(nn.Module): M = memory.size(0) scaling = float(self.head_dim) ** -0.5 - # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + # compute query with [right_context, utterance]. + query = self.emb_to_query(torch.cat([right_context, utterance])) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -593,26 +590,18 @@ class EmformerAttention(nn.Module): ) # apply output projection - outputs = self.out_proj(attention) + output_right_context_utterance = self.out_proj(attention) - output_right_context_utterance = outputs[: R + U] - output_memory = outputs[R + U :] - if self.tanh_on_mem: - output_memory = torch.tanh(output_memory) - else: - output_memory = torch.clamp(output_memory, min=-10, max=10) - - return output_right_context_utterance, output_memory, key, value + return output_right_context_utterance, key, value def forward( self, utterance: torch.Tensor, right_context: torch.Tensor, - summary: torch.Tensor, memory: torch.Tensor, attention_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: # TODO: Modify docs. """Forward pass for training and validation mode. @@ -620,17 +609,16 @@ class EmformerAttention(nn.Module): D: embedding dimension; R: length of the hard-copied right contexts; U: length of full utterance; - S: length of summary vectors; M: length of memory vectors. It computes a `big` attention matrix on full utterance and then utilizes a pre-computed mask to simulate chunk-wise attention. It concatenates three blocks: hard-copied right contexts, - full utterance, and summary vectors, as a `big` block, + and full utterance, as a `big` block, to compute the query tensor: - query = [right_context, utterance, summary], - with length Q = R + U + S. + query = [right_context, utterance], + with length Q = R + U. It concatenates the three blocks: memory vectors, hard-copied right contexts, and full utterance as another `big` block, to compute the key and value tensors: @@ -644,10 +632,8 @@ class EmformerAttention(nn.Module): r_i: right context that c_i can use; l_i: left context that c_i can use; m_i: past memory vectors from previous layer that c_i can use; - s_i: summary vector of c_i; The target chunk-wise attention is: - c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key); - s_i (in query) -> l_i, c_i, r_i (in key). 
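
For reference, a minimal standalone sketch of the attention layout described here, separate from the patch itself: after this change the query is the concatenation [right_context, utterance] with length Q = R + U, while the key and value still prepend the memory bank, giving KV = M + R + U, and a boolean (Q, KV) mask marks the disallowed positions. The sketch is single-head and skips the input/output projections; all sizes and names are illustrative, not taken from the patch.

import torch

# Toy sizes; the real ones come from the Emformer configuration.
R, U, M, B, D = 4, 16, 3, 2, 8  # right context, utterance, memory, batch, dim

right_context = torch.randn(R, B, D)
utterance = torch.randn(U, B, D)
memory = torch.randn(M, B, D)

# Query no longer contains summary vectors: Q = R + U.
query = torch.cat([right_context, utterance])        # (R + U, B, D)
# Key and value still see the memory bank: KV = M + R + U.
key = torch.cat([memory, right_context, utterance])  # (M + R + U, B, D)

# Boolean mask of shape (Q, KV); True means "may not attend".
# All False here (everything visible), just to show the shapes.
attention_mask = torch.zeros(R + U, M + R + U, dtype=torch.bool)

scores = torch.einsum("qbd,kbd->bqk", query, key) / D ** 0.5
scores = scores.masked_fill(attention_mask.unsqueeze(0), float("-inf"))
attn = torch.softmax(scores, dim=-1)
output = torch.einsum("bqk,kbd->qbd", attn, key)      # value == key for brevity
assert output.shape == (R + U, B, D)
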
+ c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key) Args: utterance (torch.Tensor): @@ -655,9 +641,6 @@ class EmformerAttention(nn.Module): right_context (torch.Tensor): Hard-copied right context frames, with shape (R, B, D), where R = num_chunks * right_context_length - summary (torch.Tensor): - Summary elements with shape (S, B, D), where S = num_chunks. - It is an empty tensor without using memory. memory (torch.Tensor): Memory elements, with shape (M, B, D), where M = num_chunks - 1. It is an empty tensor without using memory. @@ -668,31 +651,22 @@ class EmformerAttention(nn.Module): Padding mask of key tensor, with shape (B, KV). Returns: - A tuple containing 2 tensors: - - output of right context and utterance, with shape (R + U, B, D). - - memory output, with shape (M, B, D), where M = S - 1 or M = 0. + Output of right context and utterance, with shape (R + U, B, D). """ - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + output_right_context_utterance, _, _ = self._forward_impl( utterance, right_context, - summary, memory, attention_mask, padding_mask=padding_mask, ) - return output_right_context_utterance, output_memory[:-1] + return output_right_context_utterance @torch.jit.export def infer( self, utterance: torch.Tensor, right_context: torch.Tensor, - summary: torch.Tensor, memory: torch.Tensor, left_context_key: torch.Tensor, left_context_val: torch.Tensor, @@ -705,13 +679,12 @@ class EmformerAttention(nn.Module): R: length of right context; U: length of utterance, i.e., current chunk; L: length of cached left context; - S: length of summary vectors, S = 1; M: length of cached memory vectors. - It concatenates the right context, utterance (i.e., current chunk) - and summary vector of current chunk, to compute the query tensor: - query = [right_context, utterance, summary], - with length Q = R + U + S. + It concatenates the right context and utterance (i.e., current chunk) + of current chunk, to compute the query tensor: + query = [right_context, utterance], + with length Q = R + U. It concatenates the memory vectors, right context, left context, and current chunk, to compute the key and value tensors: key & value = [memory, right_context, left_context, utterance], @@ -719,8 +692,7 @@ class EmformerAttention(nn.Module): The chunk-wise attention is: chunk, right context (in query) -> - left context, chunk, right context, memory vectors (in key); - summary (in query) -> left context, chunk, right context (in key). + left context, chunk, right context, memory vectors (in key). Args: utterance (torch.Tensor): @@ -728,8 +700,6 @@ class EmformerAttention(nn.Module): right_context (torch.Tensor): Right context frames, with shape (R, B, D), where R = right_context_length. - summary (torch.Tensor): - Summary vector with shape (1, B, D), or empty tensor. memory (torch.Tensor): Memory vectors, with shape (M, B, D), or empty tensor. left_context_key (torch,Tensor): @@ -744,7 +714,6 @@ class EmformerAttention(nn.Module): Returns: A tuple containing 4 tensors: - output of right context and utterance, with shape (R + U, B, D). - - memory output, with shape (1, B, D) or (0, B, D). - attention key of left context and utterance, which would be cached for next computation, with shape (L + U, B, D). 
- attention value of left context and utterance, which would be @@ -753,28 +722,19 @@ class EmformerAttention(nn.Module): U = utterance.size(0) R = right_context.size(0) L = left_context_key.size(0) - S = summary.size(0) M = memory.size(0) - # TODO: move it outside - # query = [right context, utterance, summary] - Q = R + U + S + # query = [right context, utterance] + Q = R + U # key, value = [memory, right context, left context, uttrance] KV = M + R + L + U attention_mask = torch.zeros(Q, KV).to( dtype=torch.bool, device=utterance.device ) - # disallow attention bettween the summary vector with the memory bank - attention_mask[-1, :M] = True - ( - output_right_context_utterance, - output_memory, - key, - value, - ) = self._forward_impl( + + output_right_context_utterance, key, value = self._forward_impl( utterance, right_context, - summary, memory, attention_mask, padding_mask=padding_mask, @@ -783,7 +743,6 @@ class EmformerAttention(nn.Module): ) return ( output_right_context_utterance, - output_memory, key[M + R :], value[M + R :], ) @@ -938,49 +897,46 @@ class EmformerEncoderLayer(nn.Module): self, right_context_utterance: torch.Tensor, R: int, - memory: torch.Tensor, attention_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """Apply attention module in training and validation mode.""" utterance = right_context_utterance[R:] right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( 2, 0, 1 - ) + )[:-1, :, :] else: - summary = torch.empty(0).to( + memory = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device ) - output_right_context_utterance, output_memory = self.attention( + output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, - summary=summary, memory=memory, attention_mask=attention_mask, padding_mask=padding_mask, ) - return output_right_context_utterance, output_memory + return output_right_context_utterance def _apply_attention_module_infer( self, right_context_utterance: torch.Tensor, R: int, - memory: torch.Tensor, attn_cache: List[torch.Tensor], padding_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Apply attention module in inference mode. 1) Unpack cached states including: - - memory from previous chunks in the lower layer; + - memory from previous chunks; - attention key and value of left context from preceding chunk's compuation; 2) Apply attention computation; 3) Update cached attention states including: - - output memory of current chunk in the lower layer; + - memory of current chunk; - attention key and value in current chunk's computation, which would be resued in next chunk's computation. 
""" @@ -992,23 +948,20 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( 2, 0, 1 - ) - summary = summary[:1] + )[:1, :, :] else: - summary = torch.empty(0).to( + memory = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device ) ( output_right_context_utterance, - output_memory, next_key, next_val, ) = self.attention.infer( utterance=utterance, right_context=right_context, - summary=summary, memory=pre_memory, left_context_key=left_context_key, left_context_val=left_context_val, @@ -1017,17 +970,16 @@ class EmformerEncoderLayer(nn.Module): attn_cache = self._update_attn_cache( next_key, next_val, memory, attn_cache ) - return output_right_context_utterance, output_memory, attn_cache + return output_right_context_utterance, attn_cache def forward( self, utterance: torch.Tensor, right_context: torch.Tensor, - memory: torch.Tensor, attention_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, warmup: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Forward pass for training and validation mode. B: batch size; @@ -1041,20 +993,16 @@ class EmformerEncoderLayer(nn.Module): Utterance frames, with shape (U, B, D). right_context (torch.Tensor): Right context frames, with shape (R, B, D). - memory (torch.Tensor): - Memory elements, with shape (M, B, D). - It is an empty tensor without using memory. attention_mask (torch.Tensor): Attention mask for underlying attention module, - with shape (Q, KV), where Q = R + U + S, KV = M + R + U. + with shape (Q, KV), where Q = R + U, KV = M + R + U. padding_mask (torch.Tensor): Padding mask of ker tensor, with shape (B, KV). Returns: - A tuple containing 3 tensors: + A tuple containing 2 tensors: - output utterance, with shape (U, B, D). - output right context, with shape (R, B, D). - - output memory, with shape (M, B, D). """ R = right_context.size(0) src = torch.cat([right_context, utterance]) @@ -1076,8 +1024,8 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - src_att, output_memory = self._apply_attention_module_forward( - src, R, memory, attention_mask, padding_mask=padding_mask + src_att = self._apply_attention_module_forward( + src, R, attention_mask, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1095,24 +1043,17 @@ class EmformerEncoderLayer(nn.Module): output_utterance = src[R:] output_right_context = src[:R] - return output_utterance, output_right_context, output_memory + return output_utterance, output_right_context @torch.jit.export def infer( self, utterance: torch.Tensor, right_context: torch.Tensor, - memory: torch.Tensor, attn_cache: List[torch.Tensor], conv_cache: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - List[torch.Tensor], - torch.Tensor, - ]: + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: """Forward pass for inference. B: batch size; @@ -1126,8 +1067,6 @@ class EmformerEncoderLayer(nn.Module): Utterance frames, with shape (U, B, D). right_context (torch.Tensor): Right context frames, with shape (R, B, D). - memory (torch.Tensor): - Memory elements, with shape (M, B, D). 
attn_cache (List[torch.Tensor]): Cached attention tensors generated in preceding computation, including memory, key and value of left context. @@ -1140,9 +1079,8 @@ class EmformerEncoderLayer(nn.Module): (Tensor, Tensor, List[torch.Tensor], Tensor): - output utterance, with shape (U, B, D); - output right_context, with shape (R, B, D); - - output memory, with shape (1, B, D) or (0, B, D). - - output state. - - updated conv_cache. + - output attention cache; + - output convolution cache. """ R = right_context.size(0) src = torch.cat([right_context, utterance]) @@ -1151,12 +1089,8 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( - src, R, memory, attn_cache, padding_mask=padding_mask + src_att, attn_cache = self._apply_attention_module_infer( + src, R, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1174,7 +1108,6 @@ class EmformerEncoderLayer(nn.Module): return ( output_utterance, output_right_context, - output_memory, attn_cache, conv_cache, ) @@ -1253,11 +1186,6 @@ class EmformerEncoder(nn.Module): super().__init__() self.use_memory = memory_size > 0 - self.init_memory_op = nn.AvgPool1d( - kernel_size=chunk_length, - stride=chunk_length, - ceil_mode=True, - ) self.emformer_layers = nn.ModuleList( [ @@ -1358,16 +1286,15 @@ class EmformerEncoder(nn.Module): R: length of hard-copied right contexts; U: length of full utterance; - S: length of summary vectors; M: length of memory vectors; Q: length of attention query; KV: length of attention key and value. The shape of attention mask is (Q, KV). If self.use_memory is `True`: - query = [right_context, utterance, summary]; + query = [right_context, utterance]; key, value = [memory, right_context, utterance]; - Q = R + U + S, KV = M + R + U. + Q = R + U, KV = M + R + U. Otherwise: query = [right_context, utterance] key, value = [right_context, utterance] @@ -1378,17 +1305,14 @@ class EmformerEncoder(nn.Module): r_i: right context that c_i can use; l_i: left context that c_i can use; m_i: past memory vectors from previous layer that c_i can use; - s_i: summary vector of c_i. The target chunk-wise attention is: - c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key); - s_i (in query) -> l_i, c_i, r_i (in key). + c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key). 
""" U = utterance.size(0) num_chunks = math.ceil(U / self.chunk_length) right_context_mask = [] utterance_mask = [] - summary_mask = [] if self.use_memory: num_cols = 9 @@ -1397,9 +1321,6 @@ class EmformerEncoder(nn.Module): right_context_utterance_cols_mask = [ idx in [1, 4, 7] for idx in range(num_cols) ] - # summary attends to right context, utterance - summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)] - masks_to_concat = [right_context_mask, utterance_mask, summary_mask] else: num_cols = 6 # right context and utterance both attend to right context and @@ -1407,8 +1328,7 @@ class EmformerEncoder(nn.Module): right_context_utterance_cols_mask = [ idx in [1, 4] for idx in range(num_cols) ] - summary_cols_mask = None - masks_to_concat = [right_context_mask, utterance_mask] + masks_to_concat = [right_context_mask, utterance_mask] for chunk_idx in range(num_chunks): col_widths = self._gen_attention_mask_col_widths(chunk_idx, U) @@ -1432,12 +1352,6 @@ class EmformerEncoder(nn.Module): ) utterance_mask.append(utterance_mask_block) - if summary_cols_mask is not None: - summary_mask_block = _gen_attention_mask_block( - col_widths, summary_cols_mask, 1, utterance.device - ) - summary_mask.append(summary_mask_block) - attention_mask = ( 1 - torch.cat([torch.cat(mask) for mask in masks_to_concat]) ).to(torch.bool) @@ -1473,23 +1387,15 @@ class EmformerEncoder(nn.Module): utterance = x[:U] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) - memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] - if self.use_memory - else torch.empty(0).to(dtype=x.dtype, device=x.device) - ) - padding_mask = make_pad_mask( - memory.size(0) + right_context.size(0) + output_lengths - ) + + M = right_context.size(0) // self.chunk_length - 1 + padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths) output = utterance for layer in self.emformer_layers: - output, right_context, memory = layer( + output, right_context = layer( output, right_context, - memory, attention_mask, padding_mask=padding_mask, warmup=warmup, @@ -1525,7 +1431,6 @@ class EmformerEncoder(nn.Module): right_context at the end. states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa Cached states containing: - - past_lens: number of past frames for each sample in batch - attn_caches: attention states from preceding chunk's computation, where each element corresponds to each emformer layer - conv_caches: left context for causal convolution, where each @@ -1571,11 +1476,6 @@ class EmformerEncoder(nn.Module): right_context = x[-self.right_context_length :] utterance = x[: -self.right_context_length] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) - memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) - if self.use_memory - else torch.empty(0).to(dtype=x.dtype, device=x.device) - ) # calcualte padding mask to mask out initial zero caches chunk_mask = make_pad_mask(output_lengths).to(x.device) @@ -1611,13 +1511,11 @@ class EmformerEncoder(nn.Module): ( output, right_context, - memory, output_attn_cache, output_conv_cache, ) = layer.infer( output, right_context, - memory, padding_mask=padding_mask, attn_cache=attn_caches[layer_idx], conv_cache=conv_caches[layer_idx],