diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index d27570ff6..158b6f0fb 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -450,6 +450,10 @@ class EmformerAttention(nn.Module): Embedding dimension. nhead (int): Number of attention heads in each Emformer layer. + chunk_length (int): + Length of each input chunk. + right_context_length (int): + Length of right context. dropout (float, optional): Dropout probability. (Default: 0.0) tanh_on_mem (bool, optional): @@ -462,6 +466,8 @@ class EmformerAttention(nn.Module): self, embed_dim: int, nhead: int, + chunk_length: int, + right_context_length: int, dropout: float = 0.0, tanh_on_mem: bool = False, negative_inf: float = -1e8, @@ -479,6 +485,8 @@ class EmformerAttention(nn.Module): self.tanh_on_mem = tanh_on_mem self.negative_inf = negative_inf self.head_dim = embed_dim // nhead + self.chunk_length = chunk_length + self.right_context_length = right_context_length self.dropout = dropout self.emb_to_key_value = ScaledLinear( @@ -572,13 +580,16 @@ class EmformerAttention(nn.Module): Args: x: Input tensor, of shape (B, nhead, U, PE). U is the length of query vector. - For training and validation mode, PE = 2 * U - 1; - for inference mode, PE = L + 2 * U - 1. + For training and validation mode, + PE = 2 * U + right_context_length - 1. + For inference mode, + PE = tot_left_length + 2 * U + right_context_length - 1, + where tot_left_length = M * chunk_length. Returns: A tensor of shape (B, nhead, U, out_len). - For non-infer mode, out_len = U; - for infer mode, out_len = L + U. + For training and validation mode, out_len = U + right_context_length. + For inference mode, out_len = tot_left_length + U + right_context_length. # noqa """ B, nhead, U, PE = x.size() B_stride = x.stride(0) @@ -592,6 +603,33 @@ class EmformerAttention(nn.Module): storage_offset=PE_stride * (U - 1), ) + def _get_right_context_part( + self, matrix_bd_utterance: torch.Tensor + ) -> torch.Tensor: + """ + Args: + matrix_bd_utterance: + (B * nhead, U, U + right_context_length) + + Returns: + A tensor of shape (B * nhead, U, R), + where R = num_chunks * right_context_length. 
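As a minimal standalone sketch of the column gathering described above (toy shapes are assumed; the logic simply mirrors _get_right_context_part and is not part of the patch itself):

    import math
    import torch

    # Toy setting: U = 6 frames split into 3 chunks of length 2,
    # with a right context of length 1 per chunk.
    chunk_length, right_context_length = 2, 1
    B_nhead, U = 4, 6
    x = torch.randn(B_nhead, U, U + right_context_length)

    num_chunks = math.ceil(U / chunk_length)
    blocks = []
    for i in range(num_chunks - 1):
        start = (i + 1) * chunk_length
        blocks.append(x[:, :, start : start + right_context_length])
    # The last chunk's right context sits at the very end of the key axis.
    blocks.append(x[:, :, -right_context_length:])
    out = torch.cat(blocks, dim=2)
    assert out.shape == (B_nhead, U, num_chunks * right_context_length)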
+ """ + assert self.right_context_length > 0 + U = matrix_bd_utterance.size(1) + num_chunks = math.ceil(U / self.chunk_length) + right_context_blocks = [] + for i in range(num_chunks - 1): + start_idx = (i + 1) * self.chunk_length + end_idx = start_idx + self.right_context_length + right_context_blocks.append( + matrix_bd_utterance[:, :, start_idx:end_idx] + ) + right_context_blocks.append( + matrix_bd_utterance[:, :, -self.right_context_length :] + ) + return torch.cat(right_context_blocks, dim=2) + def _forward_impl( self, utterance: torch.Tensor, @@ -603,7 +641,15 @@ class EmformerAttention(nn.Module): pos_emb: torch.Tensor, left_context_key: Optional[torch.Tensor] = None, left_context_val: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + need_weights=False, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: """Underlying chunk-wise attention implementation.""" U, B, _ = utterance.size() R = right_context.size(0) @@ -664,10 +710,12 @@ class EmformerAttention(nn.Module): if left_context_key is not None and left_context_val is not None: # inference mode L = left_context_key.size(0) - assert PE == L + 2 * U - 1 + tot_left_length = M * self.chunk_length if M > 0 else L + assert tot_left_length >= L + assert PE == tot_left_length + 2 * U + self.right_context_length - 1 else: # training and validation mode - assert PE == 2 * U - 1 + assert PE == 2 * U + self.right_context_length - 1 pos_emb = ( self.linear_pos(pos_emb) .view(PE, self.nhead, self.head_dim) @@ -679,13 +727,49 @@ class EmformerAttention(nn.Module): ) # (B, nhead, U, PE) # rel-shift operation matrix_bd_utterance = self._rel_shift(matrix_bd_utterance) - # (B, nhead, U, U) for training and validation mode; - # (B, nhead, U, L + U) for inference mode. + # (B, nhead, U, U + right_context_length) for training and validation mode; # noqa + # (B, nhead, U, tot_left_length + U + right_context_length) for inference mode. 
# noqa matrix_bd_utterance = matrix_bd_utterance.contiguous().view( B * self.nhead, U, -1 ) matrix_bd = torch.zeros_like(matrix_ac) - matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance + if left_context_key is not None and left_context_val is not None: + # inference mode + # key: [memory, right context, left context, utterance] + # for memory + if M > 0: + matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d( + matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1), + kernel_size=(1, self.chunk_length), + stride=(1, self.chunk_length), + ).squeeze(1) + # for right_context + if R > 0: + matrix_bd[:, R : R + U, M : M + R] = matrix_bd_utterance[ + :, :, tot_left_length + U : + ] + # for left_context and utterance + matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[ + :, :, tot_left_length - L : tot_left_length + U + ] + else: + # training and validation mode + # key: [memory, right context, utterance] + # for memory + if M > 0: + matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d( + matrix_bd_utterance[:, :, :U].unsqueeze(1), + kernel_size=(1, self.chunk_length), + stride=(1, self.chunk_length), + ceil_mode=True, + ).squeeze(1)[:, :, :-1] + # for right_context + if R > 0: + matrix_bd[ + :, R : R + U, M : M + R + ] = self._get_right_context_part(matrix_bd_utterance) + # for utterance + matrix_bd[:, R : R + U, M + R :] = matrix_bd_utterance[:, :, :U] attention_weights = matrix_ac + matrix_bd @@ -717,7 +801,29 @@ class EmformerAttention(nn.Module): else: output_memory = torch.clamp(output_memory, min=-10, max=10) - return output_right_context_utterance, output_memory, key, value + if need_weights: + # average over attention heads + attention_probs = attention_probs.reshape(B, self.nhead, Q, KV) + attention_probs = attention_probs.sum(dim=1) / self.nhead + probs_memory = attention_probs[:, R : R + U, :M].sum(dim=2) + probs_frames = attention_probs[:, R : R + U, M:].sum(dim=2) + return ( + output_right_context_utterance, + output_memory, + key, + value, + probs_memory, + probs_frames, + ) + + return ( + output_right_context_utterance, + output_memory, + key, + value, + None, + None, + ) def forward( self, @@ -766,10 +872,11 @@ class EmformerAttention(nn.Module): s_i (in query) -> l_i, c_i, r_i (in key). Relative positional encoding is applied on the part of attention between - utterance (in query) and utterance (in key). Actually, it is applied on - the part of attention between each chunk (in query) and itself as well - as its left context (in key), after applying the mask: - c_i -> l_i, c_i. + [utterance] (in query) and [memory, right_context, utterance] (in key). + Actually, it is applied on the part of attention between each chunk + (in query) and itself, its memory vectors, left context, and right + context (in key), after applying the mask: + c_i (in query) -> l_i, c_i, r_i, m_i (in key). Args: utterance (torch.Tensor): @@ -791,18 +898,23 @@ class EmformerAttention(nn.Module): attention, with shape (Q, KV). pos_emb (torch.Tensor): Position encoding embedding, with shape (PE, D). - where PE = 2 * U - 1. + where PE = 2 * U + right_context_length - 1. Returns: A tuple containing 2 tensors: - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. + - summary of attention weights on memory, with shape (B, U). + - summary of attention weights on left context, utterance, and + right context, with shape (B, U). 
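As a minimal sketch of the average-pooling step used above to fill the memory columns in the training/validation branch (toy shapes are assumed; only the pooling call mirrors the patch):

    import torch
    import torch.nn.functional as F

    # Positional scores of U query frames against the U utterance key frames.
    B_nhead, U, chunk_length = 2, 8, 4
    scores = torch.randn(B_nhead, U, U)

    # Each memory vector summarizes one chunk of chunk_length frames, so its
    # positional score is approximated by the mean of the per-frame scores.
    pooled = F.avg_pool2d(
        scores.unsqueeze(1),          # (B * nhead, 1, U, U)
        kernel_size=(1, chunk_length),
        stride=(1, chunk_length),
        ceil_mode=True,
    ).squeeze(1)                      # (B * nhead, U, ceil(U / chunk_length))
    # In training mode the memory bank has M = num_chunks - 1 vectors
    # (no memory entry for the final chunk), so the last pooled column is dropped.
    memory_scores = pooled[:, :, :-1]
    assert memory_scores.shape == (B_nhead, U, 1)

Averaging is presumably a cheap stand-in for a dedicated positional embedding per memory slot: each memory vector is itself a pooled summary of one chunk, so its positional score is taken as the mean over the frames it summarizes.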
""" ( output_right_context_utterance, output_memory, _, _, + probs_memory, + probs_frames, ) = self._forward_impl( utterance, lengths, @@ -811,8 +923,14 @@ class EmformerAttention(nn.Module): memory, attention_mask, pos_emb, + need_weights=True, + ) + return ( + output_right_context_utterance, + output_memory[:-1], + probs_memory, + probs_frames, ) - return output_right_context_utterance, output_memory[:-1] def infer( self, @@ -849,9 +967,9 @@ class EmformerAttention(nn.Module): left context, chunk, right context, memory vectors (in key); summary (in query) -> left context, chunk, right context (in key). - Relative positional encoding is applied on the part of attention between - chunk (in query) and chunk itself as well as its left context (in key): - chunk (in query) -> left context, chunk (in key). + Relative positional encoding is applied on the part of attention: + chunk (in query) -> + left context, chunk, right context, memory vectors (in key); Args: utterance (torch.Tensor): @@ -874,7 +992,7 @@ class EmformerAttention(nn.Module): with shape (L, B, D), where L <= left_context_length. pos_emb (torch.Tensor): Position encoding embedding, with shape (PE, D), - where PE = L + 2 * U - 1. + where PE = M * chunk_length + 2 * U - 1 if M > 0 else L + 2 * U - 1. Returns: A tuple containing 4 tensors: @@ -905,6 +1023,8 @@ class EmformerAttention(nn.Module): output_memory, key, value, + _, + _, ) = self._forward_impl( utterance, lengths, @@ -974,6 +1094,8 @@ class EmformerEncoderLayer(nn.Module): self.attention = EmformerAttention( embed_dim=d_model, nhead=nhead, + chunk_length=chunk_length, + right_context_length=right_context_length, dropout=dropout, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, @@ -1018,6 +1140,7 @@ class EmformerEncoderLayer(nn.Module): self.layer_dropout = layer_dropout self.left_context_length = left_context_length self.chunk_length = chunk_length + self.right_context_length = right_context_length self.max_memory_size = max_memory_size self.d_model = d_model self.use_memory = max_memory_size > 0 @@ -1140,7 +1263,7 @@ class EmformerEncoderLayer(nn.Module): summary = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device ) - output_right_context_utterance, output_memory = self.attention( + output_right_context_utterance, output_memory, _, _ = self.attention( utterance=utterance, lengths=lengths, right_context=right_context, @@ -1190,14 +1313,26 @@ class EmformerEncoderLayer(nn.Module): summary = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device ) - # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1, - # for query of [utterance] (i), key-value [left_context, utterance] (j), - # the max relative distance i - j is L + U - 1 - # the min relative distance i - j is -(U - 1) - L = left_context_key.size(0) # L <= left_context_length U = utterance.size(0) - PE = L + 2 * U - 1 - tot_PE = self.left_context_length + 2 * U - 1 + # pos_emb is of shape [PE, D], where PE = M * chunk_length + 2 * U - 1, + # for query of [utterance] (i), key-value [memory vectors, left context, utterance, right context] (j) # noqa + # the max relative distance i - j is M * chunk_length + U - 1 + # the min relative distance i - j is -(U + right_context_length - 1) + M = pre_memory.size(0) # M <= max_memory_size + if self.max_memory_size > 0: + PE = M * self.chunk_length + 2 * U + self.right_context_length - 1 + tot_PE = ( + self.max_memory_size * self.chunk_length + + 2 * U + + self.right_context_length + - 1 + ) + else: + L = left_context_key.size(0) + PE = L + 2 * U + 
self.right_context_length - 1 + tot_PE = ( + self.left_context_length + 2 * U + self.right_context_length - 1 + ) assert pos_emb.size(0) == tot_PE pos_emb = pos_emb[tot_PE - PE :] ( @@ -1661,7 +1796,9 @@ class EmformerEncoder(nn.Module): right_context at the end. """ U = x.size(0) - self.right_context_length - x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U) + x, pos_emb = self.encoder_pos( + x, pos_len=U, neg_len=U + self.right_context_length + ) right_context = self._gen_right_context(x) utterance = x[:U] @@ -1734,8 +1871,12 @@ class EmformerEncoder(nn.Module): f"for dimension 1 of x, but got {x.size(1)}." ) - pos_len = self.chunk_length + self.left_context_length - neg_len = self.chunk_length + pos_len = ( + self.max_memory_size * self.chunk_length + self.chunk_length + if self.max_memory_size > 0 + else self.left_context_length + self.chunk_length + ) + neg_len = self.chunk_length + self.right_context_length x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len) right_context_start_idx = x.size(0) - self.right_context_length @@ -1807,6 +1948,13 @@ class Emformer(EncoderInterface): raise NotImplementedError( "right_context_length must be 0 or a mutiple of 4." ) + if ( + max_memory_size > 0 + and max_memory_size * chunk_length < left_context_length + ): + raise NotImplementedError( + "max_memory_size * chunk_length can not be less than left_context_length" # noqa + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py index f0a543327..57ad3b4ec 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py @@ -22,7 +22,12 @@ def test_emformer_attention_forward(): num_chunks = 3 U = num_chunks * chunk_length R = num_chunks * right_context_length - attention = EmformerAttention(embed_dim=D, nhead=8) + attention = EmformerAttention( + embed_dim=D, + nhead=8, + chunk_length=chunk_length, + right_context_length=right_context_length, + ) for use_memory in [True, False]: if use_memory: @@ -39,10 +44,15 @@ def test_emformer_attention_forward(): summary = torch.randn(S, B, D) memory = torch.randn(M, B, D) attention_mask = torch.rand(Q, KV) >= 0.5 - PE = 2 * U - 1 + PE = 2 * U + right_context_length - 1 pos_emb = torch.randn(PE, D) - output_right_context_utterance, output_memory = attention( + ( + output_right_context_utterance, + output_memory, + probs_memory, + probs_frames, + ) = attention( utterance, lengths, right_context, @@ -53,16 +63,26 @@ def test_emformer_attention_forward(): ) assert output_right_context_utterance.shape == (R + U, B, D) assert output_memory.shape == (M, B, D) + assert probs_memory.shape == (B, U) + assert probs_frames.shape == (B, U) def test_emformer_attention_infer(): from emformer import EmformerAttention B, D = 2, 256 - U = 4 - R = 2 + chunk_length = 4 + right_context_length = 2 + num_chunks = 1 + U = chunk_length * num_chunks + R = right_context_length * num_chunks L = 3 - attention = EmformerAttention(embed_dim=D, nhead=8) + attention = EmformerAttention( + embed_dim=D, + nhead=8, + chunk_length=chunk_length, + right_context_length=right_context_length, + ) for use_memory in [True, False]: if use_memory: @@ -78,7 +98,12 @@ def test_emformer_attention_infer(): memory = torch.randn(M, B, D) left_context_key = torch.randn(L, 
B, D) left_context_val = torch.randn(L, B, D) - PE = L + 2 * U - 1 + PE = ( + 2 * U + + right_context_length + - 1 + + (M * chunk_length if M > 0 else L) + ) pos_emb = torch.randn(PE, D) ( @@ -197,7 +222,7 @@ def test_emformer_encoder_layer_forward(): right_context = torch.randn(R, B, D) memory = torch.randn(M, B, D) attention_mask = torch.rand(Q, KV) >= 0.5 - PE = 2 * U - 1 + PE = 2 * U + right_context_length - 1 pos_emb = torch.randn(PE, D) output_utterance, output_right_context, output_memory = layer( @@ -227,8 +252,10 @@ def test_emformer_encoder_layer_infer(): for use_memory in [True, False]: if use_memory: - M = 3 + max_memory_size = 3 + M = 1 else: + max_memory_size = 0 M = 0 layer = EmformerEncoderLayer( @@ -239,7 +266,7 @@ def test_emformer_encoder_layer_infer(): cnn_module_kernel=kernel_size, left_context_length=left_context_length, right_context_length=right_context_length, - max_memory_size=M, + max_memory_size=max_memory_size, ) utterance = torch.randn(U, B, D) @@ -248,7 +275,16 @@ def test_emformer_encoder_layer_infer(): right_context = torch.randn(R, B, D) memory = torch.randn(M, B, D) state = None - PE = left_context_length + 2 * U - 1 + PE = ( + 2 * U + + right_context_length + - 1 + + ( + max_memory_size * chunk_length + if max_memory_size > 0 + else left_context_length + ) + ) pos_emb = torch.randn(PE, D) conv_cache = None ( @@ -273,7 +309,7 @@ def test_emformer_encoder_layer_infer(): else: assert output_memory.shape == (0, B, D) assert len(output_state) == 4 - assert output_state[0].shape == (M, B, D) + assert output_state[0].shape == (max_memory_size, B, D) assert output_state[1].shape == (left_context_length, B, D) assert output_state[2].shape == (left_context_length, B, D) assert output_state[3].shape == (1, B) @@ -334,9 +370,9 @@ def test_emformer_encoder_infer(): for use_memory in [True, False]: if use_memory: - M = 3 + max_memory_size = 3 else: - M = 0 + max_memory_size = 0 encoder = EmformerEncoder( chunk_length=chunk_length, @@ -346,7 +382,7 @@ def test_emformer_encoder_infer(): cnn_module_kernel=kernel_size, left_context_length=left_context_length, right_context_length=right_context_length, - max_memory_size=M, + max_memory_size=max_memory_size, ) states = None @@ -368,7 +404,7 @@ def test_emformer_encoder_infer(): assert len(states) == num_encoder_layers for state in states: assert len(state) == 4 - assert state[0].shape == (M, B, D) + assert state[0].shape == (max_memory_size, B, D) assert state[1].shape == (left_context_length, B, D) assert state[2].shape == (left_context_length, B, D) assert torch.equal( @@ -391,7 +427,7 @@ def test_emformer_encoder_forward_infer_consistency(): kernel_size = 31 memory_sizes = [0, 3] - for M in memory_sizes: + for max_memory_size in memory_sizes: encoder = EmformerEncoder( chunk_length=chunk_length, d_model=D, @@ -400,7 +436,7 @@ def test_emformer_encoder_forward_infer_consistency(): cnn_module_kernel=kernel_size, left_context_length=left_context_length, right_context_length=right_context_length, - max_memory_size=M, + max_memory_size=max_memory_size, ) encoder.eval() @@ -449,9 +485,9 @@ def test_emformer_forward(): for use_memory in [True, False]: if use_memory: - M = 3 + max_memory_size = 3 else: - M = 0 + max_memory_size = 0 model = Emformer( num_features=num_features, chunk_length=chunk_length, @@ -460,7 +496,7 @@ def test_emformer_forward(): cnn_module_kernel=kernel_size, left_context_length=left_context_length, right_context_length=right_context_length, - max_memory_size=M, + max_memory_size=max_memory_size, ) x = 
torch.randn(B, U + right_context_length + 3, num_features) x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,)) @@ -481,7 +517,7 @@ def test_emformer_infer(): num_features = 80 chunk_length = 8 U = chunk_length - left_context_length, right_context_length = 128, 4 + left_context_length, right_context_length = 32, 4 B, D = 2, 256 num_chunks = 3 num_encoder_layers = 2 @@ -489,9 +525,9 @@ def test_emformer_infer(): for use_memory in [True, False]: if use_memory: - M = 3 + max_memory_size = 32 else: - M = 0 + max_memory_size = 0 model = Emformer( num_features=num_features, chunk_length=chunk_length, @@ -501,7 +537,7 @@ def test_emformer_infer(): cnn_module_kernel=kernel_size, left_context_length=left_context_length, right_context_length=right_context_length, - max_memory_size=M, + max_memory_size=max_memory_size, ) states = None conv_caches = None @@ -523,7 +559,7 @@ def test_emformer_infer(): assert len(states) == num_encoder_layers for state in states: assert len(state) == 4 - assert state[0].shape == (M, B, D) + assert state[0].shape == (max_memory_size, B, D) assert state[1].shape == (left_context_length // 4, B, D) assert state[2].shape == (left_context_length // 4, B, D) assert torch.equal(
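The positional-embedding lengths exercised throughout the updated tests follow two formulas; a minimal sketch reproducing them (the helper names and example values are assumptions, only the formulas come from the assertions and tests in the diff above):

    def pos_emb_len_train(U: int, right_context_length: int) -> int:
        # training / validation mode
        return 2 * U + right_context_length - 1


    def pos_emb_len_infer(
        U: int, L: int, M: int, chunk_length: int, right_context_length: int
    ) -> int:
        # inference mode: when the memory bank is used, the "left" span covers
        # M * chunk_length frames; otherwise it is the cached left context of
        # length L.
        tot_left_length = M * chunk_length if M > 0 else L
        return tot_left_length + 2 * U + right_context_length - 1


    # Illustrative values: U=4, chunk_length=4, right_context_length=2, L=3, M=1
    #   -> PE = 4 + 8 + 2 - 1 = 13
    assert pos_emb_len_infer(U=4, L=3, M=1, chunk_length=4, right_context_length=2) == 13
    assert pos_emb_len_train(U=4, right_context_length=2) == 9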