From 648a0b37d5044b7b0e67b9dd98e9e90e83a97a64 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 14 May 2022 23:18:16 +0800 Subject: [PATCH] add Emformer module --- .../emformer.py | 640 +++++++++++++++++- .../test_emformer.py | 269 +++++++- 2 files changed, 906 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 1e551bc39..f95072970 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -1303,7 +1303,6 @@ class EmformerEncoderLayer(nn.Module): output_right_context = src[:R] return output_utterance, output_right_context, output_memory - @torch.jit.export def infer( self, utterance: torch.Tensor, @@ -1383,3 +1382,642 @@ class EmformerEncoderLayer(nn.Module): output_state, conv_cache, ) + + +def _gen_attention_mask_block( + col_widths: List[int], + col_mask: List[bool], + num_rows: int, + device: torch.device, +) -> torch.Tensor: + assert len(col_widths) == len( + col_mask + ), "Length of col_widths must match that of col_mask" + + mask_block = [ + torch.ones(num_rows, col_width, device=device) + if is_ones_col + else torch.zeros(num_rows, col_width, device=device) + for col_width, is_ones_col in zip(col_widths, col_mask) + ] + return torch.cat(mask_block, dim=1) + + +class EmformerEncoder(nn.Module): + """Implements the Emformer architecture introduced in + *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency + Streaming Speech Recognition* + [:footcite:`shi2021emformer`]. + + Args: + d_model (int): + Input dimension. + nhead (int): + Number of attention heads in each emformer layer. + dim_feedforward (int): + Hidden layer dimension of each emformer layer's feedforward network. + num_encoder_layers (int): + Number of emformer layers to instantiate. + chunk_length (int): + Length of each input segment. + dropout (float, optional): + Dropout probability. (default: 0.0) + layer_dropout (float, optional): + Layer dropout probability. (default: 0.0) + cnn_module_kernel (int): + Kernel size of convolution module. + left_context_length (int, optional): + Length of left context. (default: 0) + right_context_length (int, optional): + Length of right context. (default: 0) + max_memory_size (int, optional): + Maximum number of memory elements to use. (default: 0) + tanh_on_mem (bool, optional): + If ``true``, applies tanh to memory elements. (default: ``false``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. 
(default: -1e8) + """ + + def __init__( + self, + chunk_length: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.use_memory = max_memory_size > 0 + self.init_memory_op = nn.AvgPool1d( + kernel_size=chunk_length, + stride=chunk_length, + ceil_mode=True, + ) + + self.emformer_layers = nn.ModuleList( + [ + EmformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + chunk_length=chunk_length, + dropout=dropout, + layer_dropout=layer_dropout, + cnn_module_kernel=cnn_module_kernel, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=max_memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + for layer_idx in range(num_encoder_layers) + ] + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.max_memory_size = max_memory_size + + def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: + """Hard copy each chunk's right context and concat them.""" + T = x.shape[0] + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) + right_context_blocks = [] + for seg_idx in range(num_chunks - 1): + start = (seg_idx + 1) * self.chunk_length + end = start + self.right_context_length + right_context_blocks.append(x[start:end]) + right_context_blocks.append(x[T - self.right_context_length :]) + return torch.cat(right_context_blocks) + + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: + """Calculate column widths (key, value) in attention mask for the + chunk_idx chunk.""" + num_chunks = math.ceil(U / self.chunk_length) + rc = self.right_context_length + lc = self.left_context_length + rc_start = chunk_idx * rc + rc_end = rc_start + rc + chunk_start = max(chunk_idx * self.chunk_length - lc, 0) + chunk_end = min((chunk_idx + 1) * self.chunk_length, U) + R = rc * num_chunks + + if self.use_memory: + m_start = max(chunk_idx - self.max_memory_size, 0) + M = num_chunks - 1 + col_widths = [ + m_start, # before memory + chunk_idx - m_start, # memory + M - chunk_idx, # after memory + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + else: + col_widths = [ + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + + return col_widths + + def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor: + """Generate attention mask to simulate underlying chunk-wise attention + computation, where chunk-wise connections are filled with `False`, + and other unnecessary connections beyond chunk are filled with `True`. + + R: length of hard-copied right contexts; + U: length of full utterance; + S: length of summary vectors; + M: length of memory vectors; + Q: length of attention query; + KV: length of attention key and value. + + The shape of attention mask is (Q, KV). 
+ If self.use_memory is `True`: + query = [right_context, utterance, summary]; + key, value = [memory, right_context, utterance]; + Q = R + U + S, KV = M + R + U. + Otherwise: + query = [right_context, utterance] + key, value = [right_context, utterance] + Q = R + U, KV = R + U. + + Suppose: + c_i: chunk at index i; + r_i: right context that c_i can use; + l_i: left context that c_i can use; + m_i: past memory vectors from previous layer that c_i can use; + s_i: summary vector of c_i. + The target chunk-wise attention is: + c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key); + s_i (in query) -> l_i, c_i, r_i (in key). + """ + U = utterance.size(0) + num_chunks = math.ceil(U / self.chunk_length) + + right_context_mask = [] + utterance_mask = [] + summary_mask = [] + + if self.use_memory: + num_cols = 9 + # right context and utterance both attend to memory, right context, + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4, 7] for idx in range(num_cols) + ] + # summary attends to right context, utterance + summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)] + masks_to_concat = [right_context_mask, utterance_mask, summary_mask] + else: + num_cols = 6 + # right context and utterance both attend to right context and + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4] for idx in range(num_cols) + ] + summary_cols_mask = None + masks_to_concat = [right_context_mask, utterance_mask] + + for chunk_idx in range(num_chunks): + col_widths = self._gen_attention_mask_col_widths(chunk_idx, U) + + right_context_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + self.right_context_length, + utterance.device, + ) + right_context_mask.append(right_context_mask_block) + + utterance_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + min( + self.chunk_length, + U - chunk_idx * self.chunk_length, + ), + utterance.device, + ) + utterance_mask.append(utterance_mask_block) + + if summary_cols_mask is not None: + summary_mask_block = _gen_attention_mask_block( + col_widths, summary_cols_mask, 1, utterance.device + ) + summary_mask.append(summary_mask_block) + + attention_mask = ( + 1 - torch.cat([torch.cat(mask) for mask in masks_to_concat]) + ).to(torch.bool) + return attention_mask + + def forward( + self, x: torch.Tensor, lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and validation mode. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. + + Returns: + A tuple of 2 tensors: + - output utterance frames, with shape (U, B, D). + - output_lengths, with shape (B,), without containing the + right_context at the end. 
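+            Example (illustrative sizes only, not taken from the recipe):
+            with right_context_length=2 and U=12, x has shape (14, B, D);
+            the returned utterance frames have shape (12, B, D) and
+            output_lengths equals clamp(lengths - 2, min=0).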
+ """ + U = x.size(0) - self.right_context_length + x, pos_emb = self.encoder_pos(x, pos_len=U, neg_len=U) + + right_context = self._gen_right_context(x) + utterance = x[:U] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + attention_mask = self._gen_attention_mask(utterance) + memory = ( + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] + if self.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + + output = utterance + for layer in self.emformer_layers: + output, right_context, memory = layer( + output, + output_lengths, + right_context, + memory, + attention_mask, + pos_emb, + ) + + return output, output_lengths + + def infer( + self, + x: torch.Tensor, + lengths: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + conv_caches: Optional[List[torch.Tensor]] = None, + ) -> Tuple[ + torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor] + ]: + """Forward pass for streaming inference. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. + states (List[List[torch.Tensor]], optional): + Cached states from proceeding chunk's computation, where each + element (List[torch.Tensor]) corresponds to each emformer layer. + (default: None) + conv_caches (List[torch.Tensor], optional): + Cached tensors of left context for causal convolution, where each + element (Tensor) corresponds to each convolutional layer. + + Returns: + (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]): + - output utterance frames, with shape (U, B, D). + - output lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + - updated convolution caches from current chunk. + """ + assert x.size(0) == self.chunk_length + self.right_context_length, ( + "Per configured chunk_length and right_context_length, " + f"expected size of {self.chunk_length + self.right_context_length} " + f"for dimension 1 of x, but got {x.size(1)}." 
+ ) + + pos_len = self.chunk_length + self.left_context_length + neg_len = self.chunk_length + x, pos_emb = self.encoder_pos(x, pos_len=pos_len, neg_len=neg_len) + + right_context_start_idx = x.size(0) - self.right_context_length + right_context = x[right_context_start_idx:] + utterance = x[:right_context_start_idx] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + memory = ( + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + if self.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + output = utterance + output_states: List[List[torch.Tensor]] = [] + output_conv_caches: List[torch.Tensor] = [] + for layer_idx, layer in enumerate(self.emformer_layers): + ( + output, + right_context, + memory, + output_state, + output_conv_cache, + ) = layer.infer( + output, + output_lengths, + right_context, + memory, + pos_emb, + None if states is None else states[layer_idx], + None if conv_caches is None else conv_caches[layer_idx], + ) + output_states.append(output_state) + output_conv_caches.append(output_conv_cache) + + return output, output_lengths, output_states, output_conv_caches + + +class Emformer(EncoderInterface): + def __init__( + self, + num_features: int, + chunk_length: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 3, + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.subsampling_factor = subsampling_factor + self.right_context_length = right_context_length + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + if chunk_length % 4 != 0: + raise NotImplementedError("chunk_length must be a mutiple of 4.") + if left_context_length != 0 and left_context_length % 4 != 0: + raise NotImplementedError( + "left_context_length must be 0 or a mutiple of 4." + ) + if right_context_length != 0 and right_context_length % 4 != 0: + raise NotImplementedError( + "right_context_length must be 0 or a mutiple of 4." + ) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder = EmformerEncoder( + chunk_length=chunk_length // 4, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + layer_dropout=layer_dropout, + cnn_module_kernel=cnn_module_kernel, + left_context_length=left_context_length // 4, + right_context_length=right_context_length // 4, + max_memory_size=max_memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and non-streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). 
+ x_lens (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + + Returns: + (Tensor, Tensor): + - output embedding, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - output lengths, with shape (B,), without containing the + right_context at the end. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == x_lens.max().item() + + output, output_lengths = self.encoder(x, x_lens) # (T, N, C) + + output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + + return output, output_lengths + + @torch.jit.export + def infer( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + conv_caches: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + """Forward pass for streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + states (List[List[torch.Tensor]], optional): + Cached states from proceeding chunk's computation, where each + element (List[torch.Tensor]) corresponds to each emformer layer. + (default: None) + conv_caches (List[torch.Tensor], optional): + Cached tensors of left context for causal convolution, where each + element (Tensor) corresponds to each convolutional layer. + Returns: + (Tensor, Tensor): + - output embedding, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - output lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + - updated convolution caches from current chunk. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == x_lens.max().item() + + ( + output, + output_lengths, + output_states, + output_conv_caches, + ) = self.encoder.infer(x, x_lens, states, conv_caches) + + output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + + return output, output_lengths, output_states, output_conv_caches + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. 
The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 7 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py index d8913ef74..f0a543327 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/test_emformer.py @@ -113,7 +113,10 @@ def test_convolution_module_forward(): R = num_chunks * right_context_length kernel_size = 31 conv_module = ConvolutionModule( - chunk_length, right_context_length, D, kernel_size, + chunk_length, + right_context_length, + D, + kernel_size, ) utterance = torch.randn(U, B, D) @@ -139,7 +142,10 @@ def test_convolution_module_infer(): R = num_chunks * right_context_length kernel_size = 31 conv_module = ConvolutionModule( - chunk_length, right_context_length, D, kernel_size, + chunk_length, + right_context_length, + D, + kernel_size, ) utterance = torch.randn(U, B, D) @@ -274,6 +280,260 @@ def test_emformer_encoder_layer_infer(): assert conv_cache.shape == (B, D, kernel_size - 1) +def test_emformer_encoder_forward(): + from emformer import EmformerEncoder + + B, D = 2, 256 + chunk_length = 4 + right_context_length = 2 + left_context_length = 2 + num_chunks = 3 + U = num_chunks * chunk_length + kernel_size = 31 + num_encoder_layers = 2 + + for use_memory in [True, False]: + if use_memory: + S = num_chunks + M = S - 1 + else: + S, M = 0, 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=M, + ) + + x = torch.randn(U + right_context_length, B, D) + 
lengths = torch.randint(1, U + right_context_length + 1, (B,)) + lengths[0] = U + right_context_length + + output, output_lengths = encoder(x, lengths) + assert output.shape == (U, B, D) + assert torch.equal( + output_lengths, torch.clamp(lengths - right_context_length, min=0) + ) + + +def test_emformer_encoder_infer(): + from emformer import EmformerEncoder + + B, D = 2, 256 + num_encoder_layers = 2 + chunk_length = 4 + right_context_length = 2 + left_context_length = 2 + num_chunks = 3 + kernel_size = 31 + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=M, + ) + + states = None + conv_caches = None + for chunk_idx in range(num_chunks): + x = torch.randn(chunk_length + right_context_length, B, D) + lengths = torch.randint( + 1, chunk_length + right_context_length + 1, (B,) + ) + lengths[0] = chunk_length + right_context_length + output, output_lengths, states, conv_caches = encoder.infer( + x, lengths, states, conv_caches + ) + assert output.shape == (chunk_length, B, D) + assert torch.equal( + output_lengths, + torch.clamp(lengths - right_context_length, min=0), + ) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (left_context_length, B, D) + assert state[2].shape == (left_context_length, B, D) + assert torch.equal( + state[3], + (chunk_idx + 1) * chunk_length * torch.ones_like(state[3]), + ) + for conv_cache in conv_caches: + assert conv_cache.shape == (B, D, kernel_size - 1) + + +def test_emformer_encoder_forward_infer_consistency(): + from emformer import EmformerEncoder + + chunk_length = 4 + num_chunks = 3 + U = chunk_length * num_chunks + left_context_length, right_context_length = 1, 2 + D = 256 + num_encoder_layers = 3 + kernel_size = 31 + memory_sizes = [0, 3] + + for M in memory_sizes: + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=M, + ) + encoder.eval() + + x = torch.randn(U + right_context_length, 1, D) + lengths = torch.tensor([U + right_context_length]) + + # training mode with full utterance + forward_output, forward_output_lengths = encoder(x, lengths) + + # streaming inference mode with individual chunks + states = None + conv_caches = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[start_idx : end_idx + right_context_length] # noqa + chunk_length = torch.tensor([chunk_length]) + ( + infer_output_chunk, + infer_output_lengths, + states, + conv_caches, + ) = encoder.infer(chunk, chunk_length, states, conv_caches) + forward_output_chunk = forward_output[start_idx:end_idx] + assert torch.allclose( + infer_output_chunk, + forward_output_chunk, + atol=1e-4, + rtol=0.0, + ), ( + infer_output_chunk - forward_output_chunk + ) + + +def test_emformer_forward(): + from emformer import Emformer + + num_features = 80 + chunk_length = 16 + right_context_length = 8 + left_context_length = 8 + num_chunks = 3 + U = num_chunks * chunk_length + B, D = 2, 256 + kernel_size = 31 + 
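+    # The loop below exercises the model both with memory banks enabled
+    # (max_memory_size=3) and disabled (max_memory_size=0).  x carries U
+    # utterance frames plus right_context_length look-ahead frames, and 3
+    # extra frames so that the two stride-2 convolutions in
+    # Conv2dSubsampling produce exactly (U + right_context_length) // 4
+    # subsampled frames.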
+ for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=M, + ) + x = torch.randn(B, U + right_context_length + 3, num_features) + x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,)) + x_lens[0] = U + right_context_length + 3 + output, output_lengths = model(x, x_lens) + assert output.shape == (B, U // 4, D) + assert torch.equal( + output_lengths, + torch.clamp( + ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, min=0 + ), + ) + + +def test_emformer_infer(): + from emformer import Emformer + + num_features = 80 + chunk_length = 8 + U = chunk_length + left_context_length, right_context_length = 128, 4 + B, D = 2, 256 + num_chunks = 3 + num_encoder_layers = 2 + kernel_size = 31 + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=kernel_size, + left_context_length=left_context_length, + right_context_length=right_context_length, + max_memory_size=M, + ) + states = None + conv_caches = None + for chunk_idx in range(num_chunks): + x = torch.randn(B, U + right_context_length + 3, num_features) + x_lens = torch.randint(1, U + right_context_length + 3 + 1, (B,)) + x_lens[0] = U + right_context_length + 3 + output, output_lengths, states, conv_caches = model.infer( + x, x_lens, states, conv_caches + ) + assert output.shape == (B, U // 4, D) + assert torch.equal( + output_lengths, + torch.clamp( + ((x_lens - 1) // 2 - 1) // 2 - right_context_length // 4, + min=0, + ), + ) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (left_context_length // 4, B, D) + assert state[2].shape == (left_context_length // 4, B, D) + assert torch.equal( + state[3], + U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]), + ) + for conv_cache in conv_caches: + assert conv_cache.shape == (B, D, kernel_size - 1) + + if __name__ == "__main__": test_rel_positional_encoding() test_emformer_attention_forward() @@ -282,3 +542,8 @@ if __name__ == "__main__": test_convolution_module_infer() test_emformer_encoder_layer_forward() test_emformer_encoder_layer_infer() + test_emformer_encoder_forward() + test_emformer_encoder_infer() + test_emformer_encoder_forward_infer_consistency() + test_emformer_forward() + test_emformer_infer()
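
The chunk-wise attention mask built by EmformerEncoder._gen_attention_mask is easier to grasp with concrete numbers than from the column-width bookkeeping alone. Below is a standalone sketch in the style of the unit tests above; the toy sizes are assumptions, and it peeks at a private helper purely for inspection, so it is illustrative rather than part of the recipe.

import torch
from emformer import EmformerEncoder

# Toy configuration (assumed values, chosen only for readability).
chunk_length = 4
left_context_length = 2
right_context_length = 2
num_chunks = 3
D = 256
U = num_chunks * chunk_length

encoder = EmformerEncoder(
    chunk_length=chunk_length,
    d_model=D,
    dim_feedforward=1024,
    num_encoder_layers=1,
    cnn_module_kernel=31,
    left_context_length=left_context_length,
    right_context_length=right_context_length,
    max_memory_size=num_chunks,  # > 0, so memory banks are enabled
)

utterance = torch.randn(U, 1, D)
mask = encoder._gen_attention_mask(utterance)

# Query rows are [right_context, utterance, summary]; key/value columns are
# [memory, right_context, utterance].  True entries mark connections that
# must be suppressed; False entries are the permitted chunk-wise ones.
R = num_chunks * right_context_length  # hard-copied right-context frames
S = num_chunks                         # one summary vector per chunk
M = num_chunks - 1                     # memory vectors
assert mask.shape == (R + U + S, M + R + U)  # (21, 20) for these sizes
assert mask.dtype == torch.bool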
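
For completeness, here is a minimal sketch of how the new wrapper could be driven chunk by chunk in streaming mode, mirroring test_emformer_infer() above. All sizes are illustrative assumptions, and (as in the tests) it assumes the script is run from the recipe directory so that `from emformer import Emformer` resolves; the recipe's training and decoding scripts remain the authoritative entry points.

import torch
from emformer import Emformer

num_features = 80
chunk_length = 8          # must be a multiple of 4
left_context_length = 32  # must be 0 or a multiple of 4
right_context_length = 4  # must be 0 or a multiple of 4

model = Emformer(
    num_features=num_features,
    chunk_length=chunk_length,
    d_model=256,
    num_encoder_layers=2,
    cnn_module_kernel=31,
    left_context_length=left_context_length,
    right_context_length=right_context_length,
    max_memory_size=3,
)
model.eval()

# Each call consumes one chunk plus its right-context look-ahead; the extra
# 3 frames compensate for the Conv2dSubsampling front end.
T = chunk_length + right_context_length + 3
states = None
conv_caches = None
for _ in range(5):
    x = torch.randn(1, T, num_features)
    x_lens = torch.tensor([T])
    with torch.no_grad():
        out, out_lens, states, conv_caches = model.infer(
            x, x_lens, states, conv_caches
        )
    # out: (1, chunk_length // 4, d_model); states and conv_caches are
    # carried over to the next call.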