diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index f3ee7b0f7..1e551bc39 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -419,7 +419,7 @@ class ConvolutionModule(nn.Module):
         assert cache.shape == (B, D, self.cache_size), cache.shape
         x = torch.cat([cache, x], dim=2)  # (B, D, cache_size + U + R)
         # update cache
-        new_cache = x[:, :, -R - self.cache_size:-R]
+        new_cache = x[:, :, -R - self.cache_size : -R]
         # 1-D depth-wise conv
         x = self.depthwise_conv(x)  # (B, D, U + R)
@@ -572,7 +572,7 @@ class EmformerAttention(nn.Module):
         Args:
           x: Input tensor, of shape (B, nhead, U, PE).
             U is the length of query vector.
-            For training mode, PE = 2 * U - 1;
+            For training and validation mode, PE = 2 * U - 1;
             for inference mode, PE = L + 2 * U - 1.
 
         Returns:
@@ -666,7 +666,7 @@ class EmformerAttention(nn.Module):
             L = left_context_key.size(0)
             assert PE == L + 2 * U - 1
         else:
-            # training mode
+            # training and validation mode
             assert PE == 2 * U - 1
         pos_emb = (
             self.linear_pos(pos_emb)
@@ -679,7 +679,7 @@ class EmformerAttention(nn.Module):
         )  # (B, nhead, U, PE)
         # rel-shift operation
         matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
-        # (B, nhead, U, U) for training mode;
+        # (B, nhead, U, U) for training and validation mode;
         # (B, nhead, U, L + U) for inference mode.
         matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
             B * self.nhead, U, -1
@@ -730,7 +730,7 @@ class EmformerAttention(nn.Module):
         pos_emb: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: Modify docs.
-        """Forward pass for training mode.
+        """Forward pass for training and validation mode.
 
         B: batch size;
         D: embedding dimension;
@@ -922,3 +922,464 @@ class EmformerAttention(nn.Module):
             key[M + R :],
             value[M + R :],
         )
+
+
+class EmformerEncoderLayer(nn.Module):
+    """Emformer layer that constitutes Emformer.
+
+    Args:
+      d_model (int):
+        Input dimension.
+      nhead (int):
+        Number of attention heads.
+      dim_feedforward (int):
+        Hidden layer dimension of feedforward network.
+      chunk_length (int):
+        Length of each input segment.
+      dropout (float, optional):
+        Dropout probability. (Default: 0.1)
+      layer_dropout (float, optional):
+        Layer dropout probability. (Default: 0.075)
+      cnn_module_kernel (int):
+        Kernel size of convolution module.
+      left_context_length (int, optional):
+        Length of left context. (Default: 0)
+      right_context_length (int, optional):
+        Length of right context. (Default: 0)
+      max_memory_size (int, optional):
+        Maximum number of memory elements to use. (Default: 0)
+      tanh_on_mem (bool, optional):
+        If ``True``, applies tanh to memory elements. (Default: ``False``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights.
+        (Default: -1e8)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int,
+        chunk_length: int,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        max_memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        self.attention = EmformerAttention(
+            embed_dim=d_model,
+            nhead=nhead,
+            dropout=dropout,
+            tanh_on_mem=tanh_on_mem,
+            negative_inf=negative_inf,
+        )
+        self.summary_op = nn.AvgPool1d(
+            kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.conv_module = ConvolutionModule(
+            chunk_length,
+            right_context_length,
+            d_model,
+            cnn_module_kernel,
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean
+        # (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.layer_dropout = layer_dropout
+        self.left_context_length = left_context_length
+        self.chunk_length = chunk_length
+        self.max_memory_size = max_memory_size
+        self.d_model = d_model
+        self.use_memory = max_memory_size > 0
+
+    def _init_state(
+        self, batch_size: int, device: Optional[torch.device]
+    ) -> List[torch.Tensor]:
+        """Initialize states with zeros."""
+        empty_memory = torch.zeros(
+            self.max_memory_size, batch_size, self.d_model, device=device
+        )
+        left_context_key = torch.zeros(
+            self.left_context_length, batch_size, self.d_model, device=device
+        )
+        left_context_val = torch.zeros(
+            self.left_context_length, batch_size, self.d_model, device=device
+        )
+        past_length = torch.zeros(
+            1, batch_size, dtype=torch.int32, device=device
+        )
+        return [empty_memory, left_context_key, left_context_val, past_length]
+
+    def _unpack_state(
+        self, state: List[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Unpack cached states including:
+        1) output memory from previous chunks in the lower layer;
+        2) attention key and value of left context from the preceding chunk's
+           computation.
+ """ + past_length = state[3][0][0].item() + past_left_context_length = min(self.left_context_length, past_length) + past_memory_length = min( + self.max_memory_size, math.ceil(past_length / self.chunk_length) + ) + memory_start_idx = self.max_memory_size - past_memory_length + pre_memory = state[0][memory_start_idx:] + left_context_start_idx = ( + self.left_context_length - past_left_context_length + ) + left_context_key = state[1][left_context_start_idx:] + left_context_val = state[2][left_context_start_idx:] + return pre_memory, left_context_key, left_context_val + + def _pack_state( + self, + next_key: torch.Tensor, + next_val: torch.Tensor, + update_length: int, + memory: torch.Tensor, + state: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Pack updated states including: + 1) output memory of current chunk in the lower layer; + 2) attention key and value in current chunk's computation, which would + be resued in next chunk's computation. + 3) length of current chunk. + """ + new_memory = torch.cat([state[0], memory]) + new_key = torch.cat([state[1], next_key]) + new_val = torch.cat([state[2], next_val]) + memory_start_idx = new_memory.size(0) - self.max_memory_size + state[0] = new_memory[memory_start_idx:] + key_start_idx = new_key.size(0) - self.left_context_length + state[1] = new_key[key_start_idx:] + val_start_idx = new_val.size(0) - self.left_context_length + state[2] = new_val[val_start_idx:] + state[3] = state[3] + update_length + return state + + def _apply_conv_module_forward( + self, + right_context_utterance: torch.Tensor, + R: int, + ) -> torch.Tensor: + """Apply convolution module in training and validation mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + utterance, right_context, _ = self.conv_module(utterance, right_context) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance + + def _apply_conv_module_infer( + self, + right_context_utterance: torch.Tensor, + R: int, + conv_cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply convolution module on utterance in inference mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + utterance, right_context, conv_cache = self.conv_module.infer( + utterance, right_context, conv_cache + ) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance, conv_cache + + def _apply_attention_module_forward( + self, + right_context_utterance: torch.Tensor, + R: int, + lengths: torch.Tensor, + memory: torch.Tensor, + pos_emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply attention module in training and validation mode.""" + if attention_mask is None: + raise ValueError( + "attention_mask must be not None in training or validation mode." 
+            )
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        if self.use_memory:
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
+        else:
+            summary = torch.empty(0).to(
+                dtype=utterance.dtype, device=utterance.device
+            )
+        output_right_context_utterance, output_memory = self.attention(
+            utterance=utterance,
+            lengths=lengths,
+            right_context=right_context,
+            summary=summary,
+            memory=memory,
+            attention_mask=attention_mask,
+            pos_emb=pos_emb,
+        )
+
+        return output_right_context_utterance, output_memory
+
+    def _apply_attention_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        lengths: torch.Tensor,
+        memory: torch.Tensor,
+        pos_emb: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Apply attention module in inference mode.
+        1) Unpack cached states including:
+           - memory from previous chunks in the lower layer;
+           - attention key and value of left context from the preceding
+             chunk's computation;
+        2) Apply attention computation;
+        3) Pack updated states including:
+           - output memory of current chunk in the lower layer;
+           - attention key and value in current chunk's computation, which would
+             be reused in next chunk's computation.
+           - length of current chunk.
+        """
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        if state is None:
+            state = self._init_state(utterance.size(1), device=utterance.device)
+        pre_memory, left_context_key, left_context_val = self._unpack_state(
+            state
+        )
+        if self.use_memory:
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
+            summary = summary[:1]
+        else:
+            summary = torch.empty(0).to(
+                dtype=utterance.dtype, device=utterance.device
+            )
+        # pos_emb is of shape [PE, D], where PE = L + 2 * U - 1,
+        # for query of [utterance] (i), key-value [left_context, utterance] (j),
+        # the max relative distance i - j is L + U - 1
+        # the min relative distance i - j is -(U - 1)
+        L = left_context_key.size(0)  # L <= left_context_length
+        U = utterance.size(0)
+        PE = L + 2 * U - 1
+        tot_PE = self.left_context_length + 2 * U - 1
+        assert pos_emb.size(0) == tot_PE
+        pos_emb = pos_emb[tot_PE - PE :]
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = self.attention.infer(
+            utterance=utterance,
+            lengths=lengths,
+            right_context=right_context,
+            summary=summary,
+            memory=pre_memory,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+            pos_emb=pos_emb,
+        )
+        state = self._pack_state(
+            next_key, next_val, utterance.size(0), memory, state
+        )
+        return output_right_context_utterance, output_memory, state
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        lengths: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        pos_emb: torch.Tensor,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        r"""Forward pass for training and validation mode.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of hard-copied right contexts;
+        U: length of full utterance;
+        M: length of memory vectors.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance frames, with shape (U, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing
+            number of valid frames for i-th batch element in utterance.
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D).
+          memory (torch.Tensor):
+            Memory elements, with shape (M, B, D).
+            It is an empty tensor if memory is not used.
+          attention_mask (torch.Tensor):
+            Attention mask for underlying attention module,
+            with shape (Q, KV), where Q = R + U + S, KV = M + R + U.
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For training and validation mode, PE = 2 * U - 1.
+
+        Returns:
+          A tuple containing 3 tensors:
+            - output utterance, with shape (U, B, D).
+            - output right context, with shape (R, B, D).
+            - output memory, with shape (M, B, D).
+        """
+        R = right_context.size(0)
+        src = torch.cat([right_context, utterance])
+        src_orig = src
+
+        warmup_scale = min(0.1 + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else 0.1
+            )
+        else:
+            alpha = 1.0
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # emformer attention module
+        src_att, output_memory = self._apply_attention_module_forward(
+            src, R, lengths, memory, pos_emb, attention_mask
+        )
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src_conv = self._apply_conv_module_forward(src, R)
+        src = src + self.dropout(src_conv)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
+        output_utterance = src[R:]
+        output_right_context = src[:R]
+        return output_utterance, output_right_context, output_memory
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        lengths: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        pos_emb: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+        conv_cache: Optional[torch.Tensor] = None,
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        List[torch.Tensor],
+        torch.Tensor,
+    ]:
+        """Forward pass for inference.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of right_context;
+        U: length of utterance;
+        M: length of memory.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance frames, with shape (U, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing
+            number of valid frames for i-th batch element in utterance.
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D).
+          memory (torch.Tensor):
+            Memory elements, with shape (M, B, D).
+          pos_emb (torch.Tensor):
+            Position encoding embedding, with shape (PE, D).
+            For inference mode, PE = L + 2 * U - 1.
+          state (List[torch.Tensor], optional):
+            List of tensors representing layer internal state generated in
+            preceding computation. (default=None)
+          conv_cache (torch.Tensor, optional):
+            Cache tensor of left context for causal convolution.
+
+        Returns:
+          (Tensor, Tensor, Tensor, List[torch.Tensor], Tensor):
+            - output utterance, with shape (U, B, D);
+            - output right_context, with shape (R, B, D);
+            - output memory, with shape (1, B, D) or (0, B, D).
+            - output state.
+            - updated conv_cache.
+ """ + R = right_context.size(0) + src = torch.cat([right_context, utterance]) + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # emformer attention module + ( + src_att, + output_memory, + output_state, + ) = self._apply_attention_module_infer( + src, R, lengths, memory, pos_emb, state + ) + src = src + self.dropout(src_att) + + # convolution module + src_conv, conv_cache = self._apply_conv_module_infer(src, R, conv_cache) + src = src + self.dropout(src_conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + output_utterance = src[R:] + output_right_context = src[:R] + return ( + output_utterance, + output_right_context, + output_memory, + output_state, + conv_cache, + ) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py index 4549dad22..d8913ef74 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py @@ -154,9 +154,131 @@ def test_convolution_module_infer(): assert new_cache.shape == (B, D, kernel_size - 1) +def test_emformer_encoder_layer_forward(): + from emformer import EmformerEncoderLayer + + B, D = 2, 256 + chunk_length = 8 + right_context_length = 2 + left_context_length = 8 + kernel_size = 31 + num_chunks = 3 + U = num_chunks * chunk_length + R = num_chunks * right_context_length + + for use_memory in [True, False]: + if use_memory: + S = num_chunks + M = S - 1 + else: + S, M = 0, 0 + + layer = EmformerEncoderLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=M, + ) + + Q, KV = R + U + S, M + R + U + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + attention_mask = torch.rand(Q, KV) >= 0.5 + PE = 2 * U - 1 + pos_emb = torch.randn(PE, D) + + output_utterance, output_right_context, output_memory = layer( + utterance, + lengths, + right_context, + memory, + attention_mask, + pos_emb, + ) + assert output_utterance.shape == (U, B, D) + assert output_right_context.shape == (R, B, D) + assert output_memory.shape == (M, B, D) + + +def test_emformer_encoder_layer_infer(): + from emformer import EmformerEncoderLayer + + B, D = 2, 256 + chunk_length = 8 + right_context_length = 2 + left_context_length = 8 + kernel_size = 31 + num_chunks = 1 + U = num_chunks * chunk_length + R = num_chunks * right_context_length + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + + layer = EmformerEncoderLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=M, + ) + + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + state = None + PE = left_context_length + 2 * U - 1 + pos_emb = torch.randn(PE, D) + conv_cache = None + ( + output_utterance, + output_right_context, + output_memory, + output_state, + conv_cache, + ) = layer.infer( + utterance, + lengths, + right_context, + 
+            memory,
+            pos_emb,
+            state,
+            conv_cache,
+        )
+        assert output_utterance.shape == (U, B, D)
+        assert output_right_context.shape == (R, B, D)
+        if use_memory:
+            assert output_memory.shape == (1, B, D)
+        else:
+            assert output_memory.shape == (0, B, D)
+        assert len(output_state) == 4
+        assert output_state[0].shape == (M, B, D)
+        assert output_state[1].shape == (left_context_length, B, D)
+        assert output_state[2].shape == (left_context_length, B, D)
+        assert output_state[3].shape == (1, B)
+        assert conv_cache.shape == (B, D, kernel_size - 1)
+
+
 if __name__ == "__main__":
     test_rel_positional_encoding()
     test_emformer_attention_forward()
     test_emformer_attention_infer()
     test_convolution_module_forward()
     test_convolution_module_infer()
+    test_emformer_encoder_layer_forward()
+    test_emformer_encoder_layer_infer()
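
Illustrative usage sketch (not part of the patch above). The new infer test exercises only a single chunk, so the snippet below shows how the `state` and `conv_cache` returned by `EmformerEncoderLayer.infer` are meant to be carried across successive chunks during streaming inference. Shapes and constructor arguments follow `test_emformer_encoder_layer_infer`; the 4-chunk loop and the random `memory` stand-in (which in a full encoder would come from the lower layer) are assumptions made purely for illustration.

# sketch_streaming_infer.py -- minimal streaming loop over one EmformerEncoderLayer
import torch

from emformer import EmformerEncoderLayer

B, D = 2, 256
chunk_length = 8
right_context_length = 2
left_context_length = 8
M = 3  # max_memory_size

layer = EmformerEncoderLayer(
    d_model=D,
    nhead=8,
    dim_feedforward=1024,
    chunk_length=chunk_length,
    cnn_module_kernel=31,
    left_context_length=left_context_length,
    right_context_length=right_context_length,
    max_memory_size=M,
)
layer.eval()

U = chunk_length
PE = left_context_length + 2 * U - 1
state = None  # layer state is threaded from chunk to chunk
conv_cache = None  # causal-convolution left-context cache, also threaded
for _ in range(4):  # process 4 consecutive chunks (value chosen arbitrarily)
    utterance = torch.randn(U, B, D)
    lengths = torch.full((B,), U, dtype=torch.int64)
    right_context = torch.randn(right_context_length, B, D)
    memory = torch.randn(M, B, D)  # stand-in for lower-layer memory output
    pos_emb = torch.randn(PE, D)
    with torch.no_grad():
        (
            output_utterance,
            output_right_context,
            output_memory,
            state,
            conv_cache,
        ) = layer.infer(
            utterance,
            lengths,
            right_context,
            memory,
            pos_emb,
            state,
            conv_cache,
        )
    assert output_utterance.shape == (U, B, D)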